Error message:
Traceback (most recent call last):
  File "train.py", line 263, in <module>
    main()
  File "train.py", line 260, in main
    train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
  File "train.py", line 144, in train
    train_loss, train_loss_template, train_loss_source = train_one_epoch(args, model, train_loader, optimizer)
  File "train.py", line 104, in train_one_epoch
    masked_template, masked_source, predicted_mask_template, predicted_mask_source = model(template, source)
ValueError: not enough values to unpack (expected 4, got 2)

Process finished with exit code 1
Code where the error is raised:
for i, data in enumerate(tqdm(train_loader)):
    template, source, igt, gt_mask_template, gt_mask_source = data
    template = template.to(args.device)
    source = source.to(args.device)
    igt = igt.to(args.device)  # [source] = [igt]*[template]
    gt_mask_template = gt_mask_template.to(args.device)
    gt_mask_source = gt_mask_source.to(args.device)
    masked_template, masked_source, predicted_mask_template, predicted_mask_source = model(template, source)  # error is raised here
What is going wrong here?
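
For reference, the message itself says the call returned 2 values while the left-hand side expects 4. A minimal sketch that reproduces this in isolation is below. Note that `fake_model` and `count_returned_values` are hypothetical stand-ins written for illustration, not the real model from train.py, and the assumption that the real `forward` returns a 2-tuple is only inferred from the traceback.

```python
def fake_model(template, source):
    # Hypothetical stand-in for the real model: returns only two values,
    # which is what the traceback suggests the actual forward pass does.
    return "masked_template", "masked_source"


def count_returned_values(output):
    """Return how many values a call produced, so the unpacking can be matched to it."""
    if isinstance(output, (tuple, list)):
        return len(output)
    return 1


output = fake_model(None, None)
n_values = count_returned_values(output)

try:
    # Same shape of unpacking as line 104 of train.py: four names on the left.
    a, b, c, d = output
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(n_values)
print(error_message)
```

Running the same kind of check on the real call (assign `output = model(template, source)` first, then inspect it) should show whether the model's return statement or the unpacking line is the one that needs to change.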