问题遇到的现象和发生背景
用的pytorch1.2,python3.7做一个目标检测的模型,上层函数里的参数和实际调用的参数一致,但是在训练时pytorch里面的moudle.py一直报错:
Traceback (most recent call last):
File "train_TextGraph.py", line 237, in <module>
main()
File "train_TextGraph.py", line 218, in main
train(model, train_loader, criterion, scheduler, optimizer, epoch, logger)
File "train_TextGraph.py", line 75, in train
output, gcn_data = model(img, gt_roi, to_device)
File "/home/gpu/anaconda3/envs/DRRG/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "/home/gpu/Desktop/DRRG-master/network/textnet.py", line 125, in forward
gcn_pred = self.gcn_model(feat_batch, adj_batch, h1id_batch)
File "/home/gpu/anaconda3/envs/DRRG/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() takes 3 positional arguments but 4 were given
根据指示找到该函数,发现这个函数只有两句代码,请问这是什么原因,该咋解决?
def forward(self, *input):
raise NotImplementedError