weixin_41944061 2022-08-31 01:31

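Simple CNN on a CSV dataset: bug fix

Description: Following a tutorial, I put together a neural network. There are 283 samples in total, each with 144 features (sample data: data_trans.csv), and the labels cover 10 classes (label file: data_label_10.csv). Both files are CSV. The model is a very simple CNN copied from elsewhere.

Question 1: I set batch_size=10 in the DataLoader. Why does setting in_channels of the Conv1 layer to 1 raise an error, while it only works when set to 10, the same as batch_size? Each sample is itself one-dimensional.
Question 2: Setting out_features of the Linear layer to 10 raises "mat1 and mat2 shapes cannot be multiplied (5x144 and 720x10)". Which parameter is set incorrectly?
Question 3: Please help fix the bugs; I just need it to run.

The code and the CSV data files are on GitHub: https://github.com/sizimiya/CSDN-Question.git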

  • 脚踏南山 2022-08-31 09:25

    If this helps, please accept the answer. batch_size can now be set to any value.
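
A sketch of what likely went wrong in question 1. The asker's original model is only on GitHub, so this assumes it passed (B, 144) batches straight into an nn.Conv1d with out_channels=5 and a length-preserving kernel (kernel_size=1 here is an illustrative choice, picked to match the "5x144" in the error). Recent PyTorch versions also accept unbatched convolution input of shape (C_in, L), so a (10, 144) tensor is read as one sample with 10 channels, not as 10 samples with 1 channel each. That is why in_channels only ran when it equaled batch_size. Giving each sample an explicit channel axis, (B, 1, 144) for Conv1d or (B, 1, 12, 12) for Conv2d as the code below does via reshape(1, 12, 12), lets in_channels=1 work with any batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical raw batch: 10 samples x 144 features, no channel dimension
batch = torch.randn(10, 144)

# in_channels=10 "works" only because the whole (10, 144) tensor is read as ONE
# unbatched sample with C_in=10 channels and length L=144
conv_wrong = nn.Conv1d(in_channels=10, out_channels=5, kernel_size=1)
print(conv_wrong(batch).shape)  # torch.Size([5, 144]): the batch dimension is gone

# in_channels=1 on the same tensor fails: 10 "channels" arrive, 1 is expected
conv = nn.Conv1d(in_channels=1, out_channels=5, kernel_size=1)
try:
    conv(batch)
except RuntimeError as err:
    print("RuntimeError:", err)

# Fix: insert a channel axis so the layout is (B, C=1, L) = (10, 1, 144)
out = conv(batch.unsqueeze(1))
print(out.shape)  # torch.Size([10, 5, 144]): 10 samples, 5 channels each
```

The same idea carries over to the Conv2d version: the dataset returns (1, 12, 12) per sample, the DataLoader stacks them into (B, 1, 12, 12), and in_channels describes only the per-sample channel count.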

    # -*- coding: UTF-8 -*-
    """
    @Project: 简单cnn网络_csv数据集_bug解决.py (simple CNN, CSV dataset, bug fix)
    @Author:  陆地起飞全靠浪
    @Created: 2022-08-31-09:20
    https://ask.csdn.net/questions/7779923
    https://blog.csdn.net/weixin_41944061?type=ask
    """
    
    import pandas as pd
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from tqdm import tqdm
    from datetime import datetime
    
    
    class CsvDataset(Dataset):
        def __init__(self):
            super(CsvDataset, self).__init__()
    
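            # data_trans.csv: one sample per column (the first column is skipped)
            # data_label_10.csv: column 0 = sample name, column 1 = class label
            self.feature_path = 'data_trans.csv'
            self.label_path = 'data_label_10.csv'
    
            feature_df_ = pd.read_csv(self.feature_path)
            label_df_ = pd.read_csv(self.label_path)
    
            # Sample (column) names in the feature file must match the label file, in order
            assert feature_df_.columns.tolist()[1:] == label_df_[label_df_.columns[0]].tolist(), \
                'feature name does not match label name'
    
            # One list of 144 feature values per sample
            self.feature = [feature_df_[i].tolist() for i in feature_df_.columns[1:]]
    
            # Class labels as a plain list; CrossEntropyLoss expects integer indices in [0, 10)
            self.label = label_df_[label_df_.columns[1]].tolist()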
    
            assert len(self.feature) == len(self.label)
    
            self.length = len(self.feature)
    
        def __getitem__(self, index):
            x = self.feature[index]
            x = torch.Tensor(x)
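            # 144 features -> (C=1, H=12, W=12): Conv2d needs a channel axis per sample,
            # so batches come out as (B, 1, 12, 12) and in_channels=1 works for any B
            x = x.reshape(1, 12, 12)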
    
            y = self.label[index]
    
            return x, y
    
        def __len__(self):
            return self.length
    
    
    train_dataset = CsvDataset()
    
    train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=False)  # any batch_size works
    
    
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(1, 1))  # in_channels = channels per sample (1), not batch_size
            self.conv2 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=(1, 1))
            self.relu = nn.ReLU(inplace=True)
            self.flatten = nn.Flatten(start_dim=1, end_dim=-1)  # (B, C, H, W) -> (B, C*H*W)
            self.linear = nn.Linear(in_features=5 * 12 * 12, out_features=10, bias=False)  # 5 channels * 12 * 12 = 720 inputs, 10 classes
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.relu(x)
            print("[before flatten] x.shape: {}".format(x.shape))  # torch.Size([B, 5, 12, 12]), B = batch_size
            x = self.flatten(x)
            print("[after flatten] x.shape: {}".format(x.shape))  # torch.Size([B, 720])
            x = self.linear(x)
            # No ReLU here: return raw logits, since nn.CrossEntropyLoss applies log-softmax itself
            return x
    
    
    model = SimpleModel()
    
    optimizer = optim.SGD(params=model.parameters(), lr=0.0001, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    
    for epoch in range(2):
        with tqdm(train_loader, desc='EPOCH:{}'.format(epoch)) as train_bar:
            for (x, y) in train_bar:
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                optimizer.step()
        print('epoch:{}, last batch loss:{:.6f}'.format(epoch, loss.item()))
    
    # Date-stamped file name, e.g. model_YYYY_MM_DD.pth
    date_str = datetime.now().strftime('%Y_%m_%d')
    torch.save(model.state_dict(), 'model_{}.pth'.format(date_str))
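
Question 2 follows from the same lost batch dimension (a sketch under the same assumption about the original Conv1d model). The convolution then returns a 2-D (5, 144) tensor, and nn.Flatten(start_dim=1) has nothing to merge, so the Linear layer receives a 5x144 matrix. nn.Linear(720, 10) stores its weight as (10, 720), which the error message reports transposed as 720x10. With a real batch dimension, Flatten turns the Conv2d output (B, 5, 12, 12) into (B, 720), matching in_features=5 * 12 * 12 in the code above. The tensors below are random stand-ins, not the CSV data.

```python
import torch
import torch.nn as nn

flatten = nn.Flatten(start_dim=1, end_dim=-1)  # same setting as SimpleModel
linear = nn.Linear(in_features=720, out_features=10, bias=False)

# Conv output when the batch dim was lost: (C_out=5, L=144)
unbatched = torch.randn(5, 144)
print(flatten(unbatched).shape)  # torch.Size([5, 144]): dim 1 is already the last dim
try:
    linear(flatten(unbatched))
except RuntimeError as err:
    # mat1 is 5x144, but the weight expects 720 input features
    print("RuntimeError:", err)

# Conv2d output with a real batch dim: (B=2, C=5, H=12, W=12)
batched = torch.randn(2, 5, 12, 12)
flat = flatten(batched)
print(flat.shape)    # torch.Size([2, 720])
logits = linear(flat)
print(logits.shape)  # torch.Size([2, 10]): one score per class per sample
```

So out_features=10 was never the problem; in_features=720 is only correct once every sample keeps its own batch row.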
    
    
    
    This answer was selected by the asker as the best answer.

Question events

  • Closed by the system on September 8
  • Answer accepted on August 31
  • Question created on August 31
