以下内容包括三部分:1.报错内容 2.train.py 3.Model.py
我在下采样之后添加了出现了以下报错,我尝试在foward中将张量移到同一设备,没有解决该问题,请教一下帮忙看看,谢谢
1.报错内容:
Traceback (most recent call last):
File "train.py", line 97, in <module>
cd_preds = model(batch_img1, batch_img2)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/data/private/lizihao/HWTNet/models/Models.py", line 325, in forward
x0_0A = self.h0(x0_0A) # 添加半小波变换
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/data/private/lizihao/HWTNet/models/Models.py", line 248, in forward
out = torch.cat([wavelet_path, identity_path], dim=1)
RuntimeError: All input tensors must be on the same device. Received cpu and cuda:0