我的数据集已经划分好了训练、验证和测试,我怎么用训练集进行训练验证集进行验证测试集进行测试呢
我发现我处理后的图像文件夹是空的
# 图像文件夹路径
train_folder = "/kaggle/input/TF/train"
valid_folder = "/kaggle/input/TF/valid"
test_folder = "/kaggle/input/TF/test"
# 定义图像预处理操作
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(degrees=15), # 随机旋转
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色扭曲
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
])
# 加载RGB图像并进行预处理
def load_rgb_image(folder, filename):
image_path = os.path.join(folder, filename)
image = Image.open(image_path)
transformed_image = transform(image)
return transformed_image.unsqueeze(0) # 添加批处理维度
def load_rgb_image(folder, filename):
image_path = os.path.join(folder, filename)
if os.path.exists(image_path): # 检查图像文件是否存在
image = Image.open(image_path)
transformed_image = transform(image)
return transformed_image.unsqueeze(0)
else:
print('hello')
def apply_dct_channel(channel, threshold):
dct_channel = torch.fft.fftn(channel, dim=(-2, -1)) # 离散余弦变换
dct_channel_filtered = dct_channel * (torch.abs(dct_channel) > threshold) # 进行阈值过滤
idct_channel = torch.fft.ifftn(dct_channel_filtered, dim=(-2, -1)) # 离散余弦逆变换
return idct_channel.real.to(channel.device) # 返回实部,并将结果放回原设备
def apply_dct(image, threshold):
r_channel = image[:, 0, :, :]
g_channel = image[:, 1, :, :]
b_channel = image[:, 2, :, :]
r_processed = apply_dct_channel(r_channel, threshold)
g_processed = apply_dct_channel(g_channel, threshold)
b_processed = apply_dct_channel(b_channel, threshold)
processed_image = torch.stack([r_processed, g_processed, b_processed], dim=1)
return processed_image
def apply_srm_channel(channel, threshold):
channel_np = channel.numpy() # Convert Torch tensor to NumPy array
dct_channel = dct(dct(channel_np, axis=0, norm='ortho'), axis=1, norm='ortho')
dct_channel_filtered = dct_channel * (np.abs(dct_channel) > threshold)
idct_channel = idct(idct(dct_channel_filtered, axis=0, norm='ortho'), axis=1, norm='ortho')
return torch.from_numpy(idct_channel) # Convert NumPy array back to Torch tensor
def apply_srm(image, threshold):
r_channel = image[:, :, 0]
g_channel = image[:, :, 1]
b_channel = image[:, :, 2]
r_processed = apply_srm_channel(r_channel, threshold)
g_processed = apply_srm_channel(g_channel, threshold)
b_processed = apply_srm_channel(b_channel, threshold)
processed_image = torch.stack([r_processed, g_processed, b_processed], dim=-1)
return processed_image
# 创建EfficientNet模型
def create_efficientnet_model(input_shape, num_classes):
model = efficientnet_b0(pretrained=True)
model._fc = nn.Linear(1280, num_classes) # 替换最后一层全连接层
return model
# 获取文件夹中的所有图像文件名
def get_image_filenames(folder):
image_filenames = []
for filename in os.listdir(folder):
if filename.endswith(".jpg") or filename.endswith(".png"):
image_filenames.append(filename)
return image_filenames
# 获取训练集、验证集和测试集的图像文件名
train_filenames = get_image_filenames(train_folder)
valid_filenames = get_image_filenames(valid_folder)
test_filenames = get_image_filenames(test_folder)
# 加载RGB图像
train_rgb_images = []
for i, filename in enumerate(train_filenames):
rgb_image = load_rgb_image(train_folder, filename)
train_rgb_images.append(rgb_image)
print(f"Loading train RGB images: {i+1}/{len(train_filenames)}", end="\r")
valid_rgb_images = []
for i, filename in enumerate(valid_filenames):
rgb_image = load_rgb_image(valid_folder, filename)
valid_rgb_images.append(rgb_image)
print(f"Loading valid RGB images: {i+1}/{len(valid_filenames)}", end="\r")
test_rgb_images = []
for i, filename in enumerate(test_filenames):
rgb_image = load_rgb_image(test_folder, filename)
test_rgb_images.append(rgb_image)
print(f"Loading test RGB images: {i+1}/{len(test_filenames)}", end="\r")
# 执行DCT处理
train_dct_images = []
for i, rgb_image in enumerate(train_rgb_images):
dct_image = apply_dct(rgb_image, 20)
train_dct_images.append(dct_image)
print(f"Applying DCT to train images: {i+1}/{len(train_rgb_images)}", end="\r")
valid_dct_images = []
for i, rgb_image in enumerate(valid_rgb_images):
dct_image = apply_dct(rgb_image, 20)
valid_dct_images.append(dct_image)
print(f"Applying DCT to valid images: {i+1}/{len(valid_rgb_images)}", end="\r")
test_dct_images = []
for i, rgb_image in enumerate(test_rgb_images):
dct_image = apply_dct(rgb_image, 20)
test_dct_images.append(dct_image)
print(f"Applying DCT to test images: {i+1}/{len(test_rgb_images)}", end="\r")
# 执行SRM处理
train_srm_images = []
for i, rgb_image in enumerate(train_rgb_images):
srm_image = apply_srm(rgb_image, 20)
train_srm_images.append(srm_image)
print(f"Applying SRM to train images: {i+1}/{len(train_rgb_images)}", end="\r")
#---------------
for i, rgb_image in enumerate(train_rgb_images):
dct_image = apply_dct(rgb_image, 20)
if dct_image is not None: # 检查图像是否为空
train_dct_images.append(dct_image)
print(f"Applying DCT to train images: {i+1}/{len(train_rgb_images)}", end="\r")
valid_srm_images = []
for i, rgb_image in enumerate(valid_rgb_images):
srm_image = apply_srm(rgb_image, 20)
valid_srm_images.append(srm_image)
print(f"Applying SRM to valid images: {i+1}/{len(valid_rgb_images)}", end="\r")
test_srm_images = []
for i, rgb_image in enumerate(test_rgb_images):
srm_image = apply_srm(rgb_image, 20)
test_srm_images.append(srm_image)
print(f"Applying SRM to test images: {i+1}/{len(test_rgb_images)}", end="\r")
# 将图像和标签转换为模型输入的张量格式
train_images = torch.cat(train_rgb_images, dim=0)
valid_images = torch.cat(valid_rgb_images, dim=0)
test_images = torch.cat(test_rgb_images, dim=0)
train_labels = torch.tensor([0] * len(train_filenames))
valid_labels = torch.tensor([0] * len(valid_filenames))
test_labels = torch.tensor([0] * len(test_filenames))
#--------------------------------------------------
报错如下:---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[26], line 2
1 # 将图像和标签转换为模型输入的张量格式
----> 2 train_images = torch.cat(train_rgb_images, dim=0)
3 valid_images = torch.cat(valid_rgb_images, dim=0)
4 test_images = torch.cat(test_rgb_images, dim=0)
RuntimeError: torch.cat(): expected a non-empty list of Tensors
求解,我怀疑和自己文件路径有关系,我数据集有train、valid和test,三者中分别含有real、fake两个子文件夹,分别包含对应的图像