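I've recently been trying to reproduce a paper. The GitHub repository is https://github.com/swz30/MPRNet, and I'm running the deraining code in the Deraining folder.
I changed two values in training.yml and left the rest of the code untouched.
Contents of training.yml: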
GPU: [0,1,2,3]
VERBOSE: True
MODEL:
  MODE: 'Deraining'
  SESSION: 'MPRNet'
# Optimization arguments.
OPTIM:
  BATCH_SIZE: 2 # originally 16, changed to 2
  NUM_EPOCHS: 250
  # NEPOCH_DECAY: [10]
  LR_INITIAL: 2e-4
  LR_MIN: 1e-6
  # BETA1: 0.9
TRAINING:
  VAL_AFTER_EVERY: 5
  RESUME: False
  TRAIN_PS: 128 # originally 256, changed to 128
  VAL_PS: 128
  TRAIN_DIR: './Datasets/train' # path to training data
  VAL_DIR: './Datasets/test/Rain100L' # path to validation data
  SAVE_DIR: './checkpoints' # path to save models and images
  # SAVE_IMAGES: False
During training, after anywhere from a few to a few dozen epochs, it fails with the following error:
Traceback (most recent call last):
  File "train.py", line 109, in <module>
    for i, data in enumerate(tqdm(train_loader), 0):
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/tqdm/std.py", line 1185, in __iter__
    for obj in iterable:
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 68, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 68, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/min1/anaconda3/envs/bwq/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 43, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 128 and 126 in dimension 3 at /opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/TH/generic/THTensor.cpp:711
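I have hit this error many times: the first run failed after 9 epochs, the second at 22, the third at 31, and today's fourth run at 60. The error takes the same form every time; only the numbers in the final line's "Got 128 and 126 in dimension 3" vary.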
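Since the crash is a width mismatch when a batch is stacked (128 versus 126), one thing that might be worth checking is whether any training image is smaller than TRAIN_PS, or whether the two images of an input/target pair differ in size. Below is a minimal sketch of such a check, not code from the repository. It assumes the training folder holds same-named images in `input` and `target` subfolders and that Pillow is installed; the helper names `check_pair` and `scan_pairs` are hypothetical.

```python
import os


def check_pair(name, input_size, target_size, patch_size):
    """Return a list of problems found for one (input, target) image pair.

    Sizes are (width, height) tuples, as returned by PIL's Image.size.
    """
    problems = []
    if input_size != target_size:
        problems.append(f"{name}: input {input_size} != target {target_size}")
    for label, (w, h) in (("input", input_size), ("target", target_size)):
        if w < patch_size or h < patch_size:
            problems.append(
                f"{name}: {label} {w}x{h} is smaller than patch size {patch_size}"
            )
    return problems


def scan_pairs(input_dir, target_dir, patch_size):
    """Scan two folders of same-named images and collect every problem."""
    from PIL import Image  # imported lazily; requires Pillow

    problems = []
    for name in sorted(os.listdir(input_dir)):
        target_path = os.path.join(target_dir, name)
        if not os.path.isfile(target_path):
            problems.append(f"{name}: no matching target file")
            continue
        with Image.open(os.path.join(input_dir, name)) as inp, \
                Image.open(target_path) as tar:
            problems += check_pair(name, inp.size, tar.size, patch_size)
    return problems


# Example call (the folder layout is an assumption, adjust to the real one):
# for line in scan_pairs("./Datasets/train/input", "./Datasets/train/target", 128):
#     print(line)
```

If the scan prints nothing, the image sizes are probably not the cause and the cropping or padding logic in the data loader would be the next place to look.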
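The traceback contains a lot of environment paths. The environment actually belongs to a classmate of mine, but he is running the deblurring code from the same repository, so our environments should be the same.
Is the problem in the dataset or somewhere else? How should I fix it?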