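inp_img is the input image. padw is the desired width (the patch size ps) minus the image's width, i.e. the number of columns that need to be padded along the width, and padh is the number of rows that need to be padded.
Below is the relevant code. Is my understanding correct? And what does each position in the tuple (0,0,padw,padh) passed to TF.pad stand for?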
```python
import os

import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset

# is_image_file is not defined in this excerpt


class DataLoaderTrain(Dataset):
    def __init__(self, rgb_dir, img_options=None):
        super(DataLoaderTrain, self).__init__()
        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
        self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target
        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        index_ = index % self.sizex
        ps = self.ps
        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]
        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # PIL's .size is (width, height); the input is assumed to have the same size as the target
        w, h = tar_img.size
        padw = ps - w if w < ps else 0  # columns missing to reach ps
        padh = ps - h if h < ps else 0  # rows missing to reach ps

        # Reflect Pad in case image is smaller than patch_size
        if padw != 0 or padh != 0:
            inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect')
            tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect')

        # ... (rest of __getitem__ not shown in the question)
```
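On the question itself: according to the torchvision documentation for `torchvision.transforms.functional.pad`, a padding sequence of length 4 is read as (left, top, right, bottom). So (0, 0, padw, padh) adds nothing on the left or top, padw columns on the right and padh rows at the bottom, which brings each dimension that was smaller than ps up to exactly ps; the reading of padw and padh above matches the code. With `padding_mode='reflect'` the border is mirrored without repeating the edge pixel. The pure-Python sketch below imitates that on a small list of rows so the ordering is easy to see. `pad_tuple`, `reflect_pad_1d` and `reflect_pad_2d` are hypothetical helpers written for illustration, not torchvision functions, and they only accept pads smaller than the image dimension.

```python
def pad_tuple(w, h, ps):
    """Same logic as the question: pad only when the image is smaller than ps."""
    padw = ps - w if w < ps else 0  # columns to add
    padh = ps - h if h < ps else 0  # rows to add
    # 4-tuple order used by TF.pad: (left, top, right, bottom)
    return (0, 0, padw, padh)


def reflect_pad_1d(seq, before, after):
    """Mirror seq around its first/last element without repeating that element."""
    n = len(seq)
    if before >= n or after >= n:
        raise ValueError("this sketch only supports pads smaller than the dimension")
    head = [seq[i] for i in range(before, 0, -1)]   # mirrored start
    tail = [seq[n - 2 - i] for i in range(after)]   # mirrored end
    return head + list(seq) + tail


def reflect_pad_2d(img, padding):
    """Pad a list of rows using a (left, top, right, bottom) tuple in reflect mode."""
    left, top, right, bottom = padding
    rows = [reflect_pad_1d(row, left, right) for row in img]       # widen each row
    return [list(r) for r in reflect_pad_1d(rows, top, bottom)]    # then add rows


img = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]  # w = 3, h = 3
padding = pad_tuple(w=3, h=3, ps=5)
print(padding)  # (0, 0, 2, 2)
for row in reflect_pad_2d(img, padding):
    print(row)
# [1, 2, 3, 2, 1]
# [4, 5, 6, 5, 4]
# [7, 8, 9, 8, 7]
# [4, 5, 6, 5, 4]
# [1, 2, 3, 2, 1]
```

The original pixels stay in the top-left corner and the mirrored border is added only on the right and bottom. Because inp_img and tar_img receive the same tuple, input and target remain pixel-aligned after padding. A tuple of (padw, padh, 0, 0) would instead put the mirrored border on the left and top.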