2301_76839626 2024-05-09 22:16

KAN: RuntimeError: Expected all tensors to be on the same device, but found at least two devices


I only recently came across KAN and found the source code on GitHub, which conveniently ships a Jupyter notebook to play with. The very first initialization cell stumped me. The code is below; the failure is on the model(...) line, but I don't know how to fix it. Could some expert kindly spare a moment to take a look? Many thanks!

from kan import *
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
# plot KAN at initialization (the forward pass below is the line that errors out)
model(dataset['train_input']);
model.plot(beta=100)

Error output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 2
      1 # plot KAN at initialization
----> 2 model(dataset['train_input']);
      3 model.plot(beta=100)

File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File E:\PWmat\ML2024\KAN\pykan-master\kan\KAN.py:311, in KAN.forward(self, x)
    307 self.acts.append(x)  # acts shape: (batch, width[l])
    309 for l in range(self.depth):
--> 311     x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
    313     if self.symbolic_enabled == True:
    314         x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)

File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File F:\anaconda\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File E:\PWmat\ML2024\KAN\pykan-master\kan\KANLayer.py:176, in KANLayer.forward(self, x)
    174 y = y.permute(1, 0)  # shape (batch, size)
    175 postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
--> 176 y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
    177 y = self.mask[None, :] * y
    178 postacts = y.clone().reshape(batch, self.out_dim, self.in_dim)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
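
The message pinpoints the mismatch: at KANLayer.py:176 the layer parameter self.scale_base lives on cuda:0, while the batch coming from dataset['train_input'] is a CPU tensor. A minimal way to confirm which side is where, reusing the model and dataset objects from the snippet above:

for name, p in model.named_parameters():
    print(name, p.device)                 # KAN parameters: cuda:0 per the traceback
print(dataset['train_input'].device)      # dataset tensor: cpu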


3 answers

  • 专家-皆知 2024-05-09 22:16

    Quoted from the free WeChat mini-program 皆我百晓生:

    Make sure every tensor involved in the forward pass sits on the same device: either move the dataset tensors onto the GPU that holds the model parameters, or create both the model and the dataset on one explicitly chosen torch.device.
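
    A minimal sketch of the fix, assuming your pykan version exposes the device keyword on both KAN and create_dataset (recent releases do); the idea is to create the model and the dataset on one device:

    import torch
    from kan import *

    # choose one device for everything; fall back to CPU when no GPU is present
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0, device=device)

    # same target function as in the question
    f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
    dataset = create_dataset(f, n_var=2, device=device)

    model(dataset['train_input'])  # forward pass now sees a single device
    model.plot(beta=100)

    If your pykan version has no device keyword, move the dataset tensors to wherever the model parameters already live:

    dev = next(model.parameters()).device                 # cuda:0 in the traceback above
    dataset = {k: v.to(dev) for k, v in dataset.items()}  # all four tensors in the dict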



