运行训练脚本时报错:`TypeError: forward() takes 2 positional arguments but 3 were given`。
报错位置的代码(`module.py` 中的 `_call_impl`,即执行 `model(...)` 调用时 PyTorch `nn.Module` 运行的方法):
```python
def _call_impl(self, *input, **kwargs):
    """Run this module's forward pass, wrapped by its registered hooks.

    This is an excerpt of the method that executes when a module instance is
    called; ``model(template, source)`` in the traceback lands here. All
    positional arguments of that call are collected into ``input``. After
    forward pre-hooks (and a full backward hook, if any) have had a chance to
    replace them, they are passed to ``self.forward(*input, **kwargs)``.

    Why the reported TypeError happens: ``model(template, source)`` yields
    ``input == (template, source)``. Unless a pre-hook changes that tuple,
    ``self.forward`` receives three positional arguments: ``self``,
    ``template`` and ``source``. Python's message "forward() takes 2
    positional arguments but 3 were given" therefore means the model class's
    ``forward`` is declared with only one parameter besides ``self``. The
    fix belongs in that ``forward`` signature (or at the call site), not in
    this method.

    Args:
        *input: Positional arguments destined for ``self.forward``; forward
            pre-hooks may replace them.
        **kwargs: Keyword arguments passed to ``self.forward`` unchanged.

    Returns:
        The output of ``self.forward`` (or ``self._slow_forward`` while
        tracing), possibly replaced by a forward hook's non-None return value.
    """
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    # Forward pre-hooks (global ones first, then this module's) may rewrite
    # the positional inputs. A non-tuple return value is wrapped into a
    # 1-tuple so it can still be unpacked into forward() below.
    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result

    bw_hook = None
    if len(full_backward_hooks) > 0:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        # The reported TypeError is raised here: every element of `input`
        # becomes a positional argument of the subclass's forward().
        # NOTE(review): the traceback shows this line as
        # `t =self.forward(*input , **kwargs)`, which differs from this
        # excerpt. The installed module.py may have been edited locally;
        # if so, restore it.
        result = self.forward(*input, **kwargs)

    # Forward hooks receive (module, inputs, output); a non-None return
    # value replaces the output.
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result

    # NOTE(review): the pasted excerpt stopped at the loop above.
    # `bw_hook` and `non_full_backward_hooks` are computed but unused in it,
    # so the real method presumably does more backward-hook work before
    # returning. Compare against the installed torch/nn/modules/module.py
    # rather than relying on this copy. The caller unpacks the call's value
    # (`masked_template, predicted_mask = model(...)`), so the method must
    # return `result`.
    return result
```

报错信息:

```text
Traceback (most recent call last):
  File "train.py", line 237, in <module>
    main()
  File "train.py", line 234, in main
    train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
  File "train.py", line 122, in train
    train_loss = train_one_epoch(args, model, train_loader, optimizer)
  File "train.py", line 90, in train_one_epoch
    masked_template, predicted_mask = model(template, source)
  File "module.py", line 889, in _call_impl
    t =self.forward(*input , **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

Process finished with exit code 1
```