最近刚认识KAN,上github找了源码,里面正好有jupyter给我玩,结果初始化就给我干蒙了,源码如下,问题出在model那句但不知道如何解决.哪位Da Lao拔冗帮我看看怎么修改,感谢感谢!
from kan import *
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
# plot KAN at initialization
model(dataset['train_input']);
model.plot(beta=100)
报错内容:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[4], line 2
1 # plot KAN at initialization
----> 2 model(dataset['train_input']);
3 model.plot(beta=100)
File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File E:\PWmat\ML2024\KAN\pykan-master\kan\KAN.py:311, in KAN.forward(self, x)
307 self.acts.append(x) # acts shape: (batch, width[l])
309 for l in range(self.depth):
--> 311 x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
313 if self.symbolic_enabled == True:
314 x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File E:\PWmat\ML2024\KAN\pykan-master\kan\KANLayer.py:176, in KANLayer.forward(self, x)
174 y = y.permute(1, 0) # shape (batch, size)
175 postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
--> 176 y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
177 y = self.mask[None, :] * y
178 postacts = y.clone().reshape(batch, self.out_dim, self.in_dim)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!