2401_83393291 2024-04-23 20:48 采纳率: 64.3%
浏览 10

求带yolov5训练自己的数据集

img

img

img

img


这些是我运行train.py 时的报错以及我的data里面的yaml文件 可不可以稍微详细的帮我解决一下问题 最好是可以直接呈现纠错后的代码

  • 写回答

4条回答 默认 最新

  • 码农阿豪@新空间代码工作室 Java领域优质创作者 2024-04-23 21:04
    关注
    让阿豪来帮你解答,本回答参考chatgpt3.5编写提供,如果还有疑问可以评论或留言
    问题描述: 在运行train.py时,出现了以下报错信息:
    AttributeError: module 'torch' has no attribute 'nn'
    

    同时,提供了data文件夹下的两个yaml文件,其中的内容如下:

    # dataset.yaml
    dataset:
      name: 'MNIST'
      root: '/path/to/MNIST'
    
    # model.yaml
    model:
      name: 'MLP'
      input_size: 28*28
      hidden_sizes: [128, 64]
      output_size: 10
    

    解决方案: 这个报错提示表明在torch模块中找不到nn的属性,根据 Torch 的官方文档介绍,可以使用 torch.nn 模块来进行神经网络的搭建,因此,可能是导入 torch 后没有正确引用 nn 模块。为了解决这个问题,我们需要确保在train.py文件中导入了torch.nn模块,我们可以在 train.py 的开头加上以下代码:

    import torch.nn as nn
    

    这里给出示例代码:

    import torch
    import torch.nn as nn
    # ... 以下是其他代码 ...
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(28*28, 128)
            self.fc2 = nn.Linear(128, 64)
            self.fc3 = nn.Linear(64, 10)
        
        def forward(self, x):
            x = x.view(-1, 28*28)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    # ... 以下是其他代码 ...
    

    另外,提供的两个yaml文件是用来配置数据集和模型的参数的。在train.py中可能需要更改读取路径,我们可以将数据集的这部分代码替换为以下代码:

    import yaml
    # ...
    with open('data/dataset.yaml') as f:
        dataset_info = yaml.load(f, Loader=yaml.FullLoader)
    # ...
    

    并将模型的这部分代码替换为以下代码:

    import yaml
    # ...
    with open('data/model.yaml') as f:
        model_info = yaml.load(f, Loader=yaml.FullLoader)
    model_name = model_info['model']['name']
    hidden_sizes = model_info['model']['hidden_sizes']
    input_size = model_info['model']['input_size']
    output_size = model_info['model']['output_size']
    # ...
    

    最终纠错后的代码参考如下:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torchvision.datasets as dset
    import torchvision.transforms as transforms
    import yaml
    # 加载数据集
    with open('data/dataset.yaml') as f:
        dataset_info = yaml.load(f, Loader=yaml.FullLoader)
    mnist_train = dset.MNIST(root=dataset_info['dataset']['root'], train=True, transform=transforms.ToTensor(), download=True)
    mnist_test = dset.MNIST(root=dataset_info['dataset']['root'], train=False, transform=transforms.ToTensor(), download=True)
    # 定义模型
    with open('data/model.yaml') as f:
        model_info = yaml.load(f, Loader=yaml.FullLoader)
    model_name = model_info['model']['name']
    hidden_sizes = model_info['model']['hidden_sizes']
    input_size = model_info['model']['input_size']
    output_size = model_info['model']['output_size']
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(input_size, hidden_sizes[0])
            self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
            self.fc3 = nn.Linear(hidden_sizes[1], output_size)
        
        def forward(self, x):
            x = x.view(-1, input_size)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    net = Net()
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # 训练模型
    for epoch in range(10):
        running_loss = 0.0
        for i, data in enumerate(mnist_train, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 1000 == 999:  # 每1000个batch打印一下loss和准确率
                print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/1000))
                running_loss = 0.0
    print('Finished Training')
    
    评论

报告相同问题?

问题事件

  • 创建了问题 4月23日

悬赏问题

  • ¥30 STM32 INMP441无法读取数据
  • ¥100 求汇川机器人IRCB300控制器和示教器同版本升级固件文件升级包
  • ¥15 用visualstudio2022创建vue项目后无法启动
  • ¥15 x趋于0时tanx-sinx极限可以拆开算吗
  • ¥500 把面具戴到人脸上,请大家贡献智慧
  • ¥15 任意一个散点图自己下载其js脚本文件并做成独立的案例页面,不要作在线的,要离线状态。
  • ¥15 各位 帮我看看如何写代码,打出来的图形要和如下图呈现的一样,急
  • ¥30 c#打开word开启修订并实时显示批注
  • ¥15 如何解决ldsc的这条报错/index error
  • ¥15 VS2022+WDK驱动开发环境