import os
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
# Preprocessing shared by the input image and its tag image in MyDataset.
# torchvision's ToTensor turns a PIL image / HxWxC ndarray into a CxHxW
# tensor (uint8 input is scaled to [0, 1] per the torchvision docs).
transform = transforms.Compose(
    transforms=[transforms.ToTensor()],
)
class MyDataset(Dataset):
    """Paired image / tag dataset read from two sibling folders under ``path``.

    Expected layout::

        path/imangesPng/<stem>.png    # tag image (presumably the mask - confirm)
        path/imangesJpeg/<stem>.jpg   # input image sharing the same stem

    Each item is ``(transform(image), transform(tag_image))``; both files are
    loaded through ``keep_image_size_open`` at ``size``.
    """

    def __init__(self, path, size=(1024, 1024)):
        """Index the PNG tag files found in ``path/imangesPng``.

        Args:
            path: Dataset root directory.
            size: Size passed to ``keep_image_size_open`` for both images.
                Defaults to the previously hard-coded (1024, 1024).
        """
        self.path = path
        self.size = size
        tag_dir = os.path.join(path, 'imangesPng')
        # Sort for a deterministic sample order (os.listdir order is
        # arbitrary) and keep only .png entries so stray files such as
        # Thumbs.db / desktop.ini are not indexed as samples.
        self.name = sorted(
            entry for entry in os.listdir(tag_dir)
            if entry.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed tag images."""
        return len(self.name)

    @staticmethod
    def _image_name(tag_name):
        """Map a tag filename to its input-image filename by swapping the extension.

        Only the extension changes. ``str.replace('png', 'jpg')`` would also
        rewrite any "png" inside the stem (e.g. ``png_01.png``).
        """
        stem, _ = os.path.splitext(tag_name)
        return stem + '.jpg'

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image_tensor, tag_tensor)``."""
        tag_name = self.name[index]
        tag_path = os.path.join(self.path, 'imangesPng', tag_name)
        img_path = os.path.join(self.path, 'imangesJpeg', self._image_name(tag_name))
        tag_image = keep_image_size_open(tag_path, self.size)
        image = keep_image_size_open(img_path, self.size)
        # NOTE(review): the same ToTensor transform is applied to the tag image.
        # If tags hold integer class labels, ToTensor's /255 scaling of uint8
        # data may be unwanted - confirm against the training loss.
        return transform(image), transform(tag_image)
if __name__ == '__main__':
    # Smoke test: load the first sample and print both tensor shapes.
    import sys

    # The dataset root may be passed as the first CLI argument; otherwise the
    # author's local path is used. The raw string keeps backslashes literal,
    # so the Windows path separators need no escaping.
    default_root = r'E:\pycharm\XrayFemurDivision\unet_model\Imanges'
    root = sys.argv[1] if len(sys.argv) > 1 else default_root
    data = MyDataset(root)
    # Index once: every data[0] re-runs __getitem__ and reloads both files.
    image, tag_image = data[0]
    print(image.shape)
    print(tag_image.shape)
# Note: also tried changing the directory path separator to a backslash ("\").