2301_82090566 2026-01-30 11:38

A problem encountered while trying to use DnCNN

When I run the dataset code for DnCNN, the train.h5 file is never created, and the dataset code raises no error.

import os
import os.path
import numpy as np
import random
import h5py
import torch
import cv2
import glob
import torch.utils.data as udata
from utils import data_augmentation

def normalize(data):
    return data/255.

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def prepare_data(data_path, patch_size, stride, aug_times=1):
    # train
    print('process training data')
    scales = [1, 0.9, 0.8, 0.7]
    files = glob.glob(os.path.join(data_path, 'train', '*.png'))
    files.sort()
    h5f = h5py.File('train.h5', 'w')
    train_num = 0
    for i in range(len(files)):
        img = cv2.imread(files[i])
        h, w, c = img.shape
        for k in range(len(scales)):
            Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)
            Img = np.expand_dims(Img[:,:,0].copy(), 0)
            Img = np.float32(normalize(Img))
            patches = Im2Patch(Img, win=patch_size, stride=stride)
            print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3]*aug_times))
            for n in range(patches.shape[3]):
                data = patches[:,:,:,n].copy()
                h5f.create_dataset(str(train_num), data=data)
                train_num += 1
                for m in range(aug_times-1):
                    data_aug = data_augmentation(data, np.random.randint(1,8))
                    h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
                    train_num += 1
    h5f.close()
    # val
    print('\nprocess validation data')
    files.clear()
    files = glob.glob(os.path.join(data_path, 'Set12', '*.png'))
    files.sort()
    h5f = h5py.File('val.h5', 'w')
    val_num = 0
    for i in range(len(files)):
        print("file: %s" % files[i])
        img = cv2.imread(files[i])
        img = np.expand_dims(img[:,:,0], 0)
        img = np.float32(normalize(img))
        h5f.create_dataset(str(val_num), data=img)
        val_num += 1
    h5f.close()
    print('training set, # samples %d\n' % train_num)
    print('val set, # samples %d\n' % val_num)

class Dataset(udata.Dataset):
    def __init__(self, train=True):
        super(Dataset, self).__init__()
        self.train = train
        if self.train:
            h5f = h5py.File('train.h5', 'r')
        else:
            h5f = h5py.File('val.h5', 'r')
        self.keys = list(h5f.keys())
        random.shuffle(self.keys)
        h5f.close()
    def __len__(self):
        return len(self.keys)
    def __getitem__(self, index):
        if self.train:
            h5f = h5py.File('train.h5', 'r')
        else:
            h5f = h5py.File('val.h5', 'r')
        key = self.keys[index]
        data = np.array(h5f[key])
        h5f.close()
        return torch.Tensor(data)



2 answers

  • 檀越@新空间 2026-01-30 11:39

    Good morning ☀️☀️☀️
    This answer is based on Tongyi Qianwen (通义千问).

    The problem you are running into is that when the DnCNN dataset code runs, the train.h5 file is never created, and the script reports no error. That points to a latent issue in the code, such as a wrong path, a permissions problem, or a processing step that never executes.


    I. Possible causes

    1. Wrong file path

      • h5py.File('train.h5', 'w') tries to create the file in the current working directory. If that directory is not writable, or the environment you run in (a Jupyter notebook, a server job, etc.) resolves relative paths somewhere you do not expect, the file will not appear where you look for it. A quick check is sketched right after this list.
    2. The data-preparation code never runs

      • Inside prepare_data, if files is empty (no .png files were found), the patch loop is skipped; and if prepare_data itself is never called, the line h5py.File('train.h5', 'w') never executes, so no file is created at all.
    3. The file is not closed properly

      • You do call h5f.close(), but if an exception is raised before that line, buffered data may never reach disk.
    4. Missing diagnostics

      • Without explicit checks and log output, silent failure modes (such as an empty file list) produce no visible message.
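
    To rule out cause 1, print where the script is actually running and prove that a file can be created there. A minimal sketch (the probe file name is arbitrary):

    import os

    # h5py.File('train.h5', 'w') resolves relative to the current working
    # directory, which in notebooks and IDEs is often not the script's folder.
    print("cwd:", os.getcwd())
    print("writable:", os.access(os.getcwd(), os.W_OK))

    # Prove a file can really be created here.
    with open("probe.tmp", "w") as f:
        f.write("ok")
    os.remove("probe.tmp")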

    II. Solutions

    1. Confirm the data path is correct

    • Check that data_path points at the correct training-image directory (for example './data/train/').
    • Make sure the directory actually contains .png image files.
    • Add a print statement to verify:
    print("Files found:", files)
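
    If that prints an empty list, inspect the directory itself. A small follow-up check, assuming data_path is the same argument you pass to prepare_data (note that glob is case-sensitive on Linux, so '*.png' will not match '*.PNG'):

    import os
    import glob

    train_dir = os.path.join(data_path, 'train')   # e.g. './data/train'
    print("directory exists:", os.path.isdir(train_dir))
    if os.path.isdir(train_dir):
        print("first entries:", os.listdir(train_dir)[:5])
    print("png files found:", len(glob.glob(os.path.join(train_dir, '*.png'))))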
    

    2. Add exception handling and log output

    • Wrap the body in a try-except block to catch and print any error, and log a message whenever a branch exits early.
    def prepare_data(data_path, patch_size, stride, aug_times=1):
        try:
            print('process training data')
            scales = [1, 0.9, 0.8, 0.7]
            files = glob.glob(os.path.join(data_path, 'train', '*.png'))
            if not files:
                print("No training images found in the specified path.")
                return
            files.sort()
            h5f = h5py.File('train.h5', 'w')
            train_num = 0
            for i in range(len(files)):
                img = cv2.imread(files[i])
                if img is None:
                    print(f"Failed to read image: {files[i]}")
                    continue
                h, w, c = img.shape
                for k in range(len(scales)):
                # cv2.resize expects dsize as (width, height), so width goes first
                Img = cv2.resize(img, (int(w*scales[k]), int(h*scales[k])), interpolation=cv2.INTER_CUBIC)
                    Img = np.expand_dims(Img[:,:,0].copy(), 0)
                    Img = np.float32(normalize(Img))
                    patches = Im2Patch(Img, win=patch_size, stride=stride)
                    print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3]*aug_times))
                    for n in range(patches.shape[3]):
                        data = patches[:,:,:,n].copy()
                        h5f.create_dataset(str(train_num), data=data)
                        train_num += 1
                        for m in range(aug_times-1):
                            data_aug = data_augmentation(data, np.random.randint(1,8))
                            h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
                            train_num += 1
            h5f.close()
            # val
            print('\nprocess validation data')
            files.clear()
            files = glob.glob(os.path.join(data_path, 'Set12', '*.png'))
            if not files:
                print("No validation images found in the specified path.")
                return
            files.sort()
            h5f = h5py.File('val.h5', 'w')
            val_num = 0
            for i in range(len(files)):
                print("file: %s" % files[i])
                img = cv2.imread(files[i])
                if img is None:
                    print(f"Failed to read image: {files[i]}")
                    continue
                img = np.expand_dims(img[:,:,0], 0)
                img = np.float32(normalize(img))
                h5f.create_dataset(str(val_num), data=img)
                val_num += 1
            h5f.close()
            print('training set, # samples %d\n' % train_num)
            print('val set, # samples %d\n' % val_num)
        except Exception as e:
            print("Error occurred during data preparation:", e)
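
    A related note on cause 3: opening the file with a context manager removes that risk, because the file is closed and flushed even if patch extraction raises an exception halfway through. A minimal sketch of the pattern:

    import numpy as np
    import h5py

    sample = np.zeros((1, 40, 40), np.float32)  # placeholder patch

    # 'with' guarantees the HDF5 file is closed (and buffered data flushed)
    # even if an exception is raised inside the block.
    with h5py.File('train.h5', 'w') as h5f:
        h5f.create_dataset('0', data=sample)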
    

    3. Make sure you have write permission

    • Confirm that you are allowed to create files in the current working directory.
    • If you are running on a server, try writing into an explicit directory (such as ./data/):
    h5f = h5py.File(os.path.join(data_path, 'train.h5'), 'w')
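
    To take the guesswork out of where the file lands, you can also build an absolute output path and create the directory up front. A sketch, assuming data_path is the same argument passed to prepare_data:

    import os
    import h5py

    out_dir = os.path.abspath(data_path)   # resolves e.g. './data/' fully
    os.makedirs(out_dir, exist_ok=True)    # create the directory if missing
    h5f = h5py.File(os.path.join(out_dir, 'train.h5'), 'w')
    print("writing to:", h5f.filename)     # h5py reports the resolved path
    h5f.close()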
    

    4. Make sure prepare_data is actually called

    This is the most likely cause of your symptom: the script you posted only defines functions and a class, so nothing ever executes prepare_data, and h5py.File('train.h5', 'w') never runs. Add an entry point, for example:

    if __name__ == '__main__':
        data_path = './data/'
        prepare_data(data_path, patch_size=40, stride=10, aug_times=1)
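
    After adding the call and running the script once, a quick sanity check confirms that the file exists and how many patches it holds (the path assumes train.h5 was written to the current working directory):

    import os
    import h5py

    path = 'train.h5'  # adjust if you wrote it into data_path instead
    if os.path.exists(path):
        with h5py.File(path, 'r') as h5f:
            print("datasets in train.h5:", len(h5f.keys()))
    else:
        print("train.h5 is still missing - prepare_data did not run or wrote elsewhere")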
    

    III. Revised complete code (key parts)

    import os
    import os.path
    import numpy as np
    import random
    import h5py
    import torch
    import cv2
    import glob
    import torch.utils.data as udata
    from utils import data_augmentation
    
    def normalize(data):
        return data / 255.
    
    def Im2Patch(img, win, stride=1):
        k = 0
        endc = img.shape[0]
        endw = img.shape[1]
        endh = img.shape[2]
        patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
        TotalPatNum = patch.shape[1] * patch.shape[2]
        Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
        for i in range(win):
            for j in range(win):
                patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
                Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
                k += 1
        return Y.reshape([endc, win, win, TotalPatNum])
    
    def prepare_data(data_path, patch_size, stride, aug_times=1):
        try:
            print('process training data')
            scales = [1, 0.9, 0.8, 0.7]
            files = glob.glob(os.path.join(data_path, 'train', '*.png'))
            if not files:
                print("No training images found in the specified path.")
                return
            files.sort()
            h5f = h5py.File(os.path.join(data_path, 'train.h5'), 'w')
            train_num = 0
            for i in range(len(files)):
                img = cv2.imread(files[i])
                if img is None:
                    print(f"Failed to read image: {files[i]}")
                    continue
                h, w, c = img.shape
                for k in range(len(scales)):
                    # cv2.resize expects dsize as (width, height), so width goes first
                    Img = cv2.resize(img, (int(w * scales[k]), int(h * scales[k])), interpolation=cv2.INTER_CUBIC)
                    Img = np.expand_dims(Img[:, :, 0].copy(), 0)
                    Img = np.float32(normalize(Img))
                    patches = Im2Patch(Img, win=patch_size, stride=stride)
                    print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3] * aug_times))
                    for n in range(patches.shape[3]):
                        data = patches[:, :, :, n].copy()
                        h5f.create_dataset(str(train_num), data=data)
                        train_num += 1
                        for m in range(aug_times - 1):
                            data_aug = data_augmentation(data, np.random.randint(1, 8))
                            h5f.create_dataset(str(train_num) + "_aug_%d" % (m + 1), data=data_aug)
                            train_num += 1
            h5f.close()
            # val
            print('\nprocess validation data')
            files.clear()
            files = glob.glob(os.path.join(data_path, 'Set12', '*.png'))
            if not files:
                print("No validation images found in the specified path.")
                return
            files.sort()
            h5f = h5py.File(os.path.join(data_path, 'val.h5'), 'w')
            val_num = 0
            for i in range(len(files)):
                print("file: %s" % files[i])
                img = cv2.imread(files[i])
                if img is None:
                    print(f"Failed to read image: {files[i]}")
                    continue
                img = np.expand_dims(img[:, :, 0], 0)
                img = np.float32(normalize(img))
                h5f.create_dataset(str(val_num), data=img)
                val_num += 1
            h5f.close()
            print('training set, # samples %d\n' % train_num)
            print('val set, # samples %d\n' % val_num)
        except Exception as e:
            print("Error occurred during data preparation:", e)
    
    class Dataset(udata.Dataset):
        def __init__(self, data_path='.', train=True):
            super(Dataset, self).__init__()
            self.train = train
            # Store the directory that holds the .h5 files instead of
            # relying on a global variable that may not exist at import time.
            self.data_path = data_path
            if self.train:
                h5f = h5py.File(os.path.join(self.data_path, 'train.h5'), 'r')
            else:
                h5f = h5py.File(os.path.join(self.data_path, 'val.h5'), 'r')
            self.keys = list(h5f.keys())
            random.shuffle(self.keys)
            h5f.close()
    
        def __len__(self):
            return len(self.keys)
    
        def __getitem__(self, index):
            if self.train:
                h5f = h5py.File(os.path.join(self.data_path, 'train.h5'), 'r')
            else:
                h5f = h5py.File(os.path.join(self.data_path, 'val.h5'), 'r')
            key = self.keys[index]
            data = np.array(h5f[key])
            h5f.close()
            return torch.Tensor(data)
    
    if __name__ == '__main__':
        data_path = './data/'
        prepare_data(data_path, patch_size=40, stride=10, aug_times=1)
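
    With this version, Dataset takes the same data_path, so a typical training setup looks like this (batch size and worker count are illustrative):

    import torch.utils.data as udata

    # Assumes prepare_data has already produced train.h5 under data_path.
    dataset_train = Dataset(data_path='./data/', train=True)
    loader_train = udata.DataLoader(dataset_train, batch_size=128,
                                    shuffle=True, num_workers=4)
    print("# of training samples:", len(dataset_train))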
    

    IV. Summary

    The key steps to resolve this:

    1. Check that the file path is correct and that data_path actually contains image files.
    2. Add exception handling and log output so that errors are surfaced instead of passing silently.
    3. Verify write permission, so a permissions problem cannot silently prevent the file from being created.
    4. Fix the file paths used in prepare_data and Dataset, and prefer absolute paths for robustness.

    If you can share the actual value of data_path and your directory layout, I can help you debug further.

