weixin_41944061 2022-08-31 01:31 采纳率: 66.7%
浏览 190
已结题

简单cnn网络,csv数据集,bug解决

说明:跟教程学搞了一个神经网络,一共283个样本,每个样本144个特征(样本数据data_trans.csv),样本标签是10分类(标签文件data_label_10.csv),样本都是csv格式的,模型是随便抄的很简单的cnn。

问题1:在Dataloader中设置的batch_size=10,为什么在Conv1卷积层中in_channel设为1会出bug,设置为batch_size一样的10才行,样本本身就是1维的啊。
问题2:Linear层out_features设为10会报错,“mat1 and mat2 shapes cannot be multiplied (5x144 and 720x10)”,是哪个参数设置错误了啊?
问题3:请帮忙处理bug,能跑通就行。

代码和数据csv文件都在Github里,链接https://github.com/sizimiya/CSDN-Question.git

  • 写回答

2条回答 默认 最新

  • 脚踏南山 2022-08-31 09:25
    关注

    有用记得采纳,batch_size=随便改

    # -*- coding: UTF-8 -*-
    """
    @项目名称:简单cnn网络_csv数据集_bug解决.py
    @作   者:陆地起飞全靠浪
    @创建日期:2022-08-31-09:20
    https://ask.csdn.net/questions/7779923
    https://blog.csdn.net/weixin_41944061?type=ask
    """
    
    import pandas as pd
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from tqdm import tqdm
    from datetime import datetime
    
    
    class CsvDataset(Dataset):
        def __init__(self):
            super(CsvDataset, self).__init__()
    
            self.feature_path = 'data_trans.csv'
            self.label_path = 'data_label_10.csv'
    
            feature_df_ = pd.read_csv(self.feature_path)
            label_df_ = pd.read_csv(self.label_path)
    
            assert feature_df_.columns.tolist()[1:] == label_df_[label_df_.columns[0]].tolist(), \
                'feature name does not match label name'
    
            self.feature = [feature_df_[i].tolist() for i in feature_df_.columns[1:]]
    
            self.label = label_df_[label_df_.columns[1]]
    
            assert len(self.feature) == len(self.label)
    
            self.length = len(self.feature)
    
        def __getitem__(self, index):
            x = self.feature[index]
            x = torch.Tensor(x)
            x = x.reshape(1,12, 12)
    
            y = self.label[index]
    
            return x, y
    
        def __len__(self):
            return self.length
    
    
    train_dataset = CsvDataset()
    
    train_loader = DataLoader(dataset=train_dataset, batch_size=2,  shuffle=False)
    
    
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(1, 1))  # ***************************
            self.conv2 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=(1, 1))
            self.relu = nn.ReLU(inplace=True)
            self.flatten = nn.Flatten(start_dim=1, end_dim=-1)  # (B, C, H ,W)
            self.linear = nn.Linear(in_features=5 * 12 * 12, out_features=10, bias=False)  # ****************************
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.relu(x)
            print("[before flatten] x.shape: {}".format(x.shape))  # torch.Size([1, 5, 12, 12])
            x = self.flatten(x)
            print("[after flatten] x.shape: {}".format(x.shape))  # torch.Size([1, 720])
            x = self.linear(x)
            x = self.relu(x)
            return x
    
    
    model = SimpleModel()
    
    optimizer = optim.SGD(params=model.parameters(), lr=0.0001, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    
    for epoch in range(2):
        with tqdm(train_loader, desc='EPOCH:{}'.format(epoch)) as train_bar:
            for (x, y) in train_bar:
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                optimizer.step()
        print('epoch:{}, loss:{:.6f}'.format(epoch, loss))
    
    time = str(datetime.now()).split('')[0].replace('-', '_')
    torch.save(model.state_dict(), 'model_{}.pth'.format(time))
    
    
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 9月8日
  • 已采纳回答 8月31日
  • 创建了问题 8月31日

悬赏问题

  • ¥15 echarts动画效果失效的问题。官网下载的例子。
  • ¥60 许可证msc licensing软件报错显示已有相同版本软件,但是下一步显示无法读取日志目录。
  • ¥15 Attention is all you need 的代码运行
  • ¥15 一个服务器已经有一个系统了如果用usb再装一个系统,原来的系统会被覆盖掉吗
  • ¥15 使用esm_msa1_t12_100M_UR50S蛋白质语言模型进行零样本预测时,终端显示出了sequence handled的进度条,但是并不出结果就自动终止回到命令提示行了是怎么回事:
  • ¥15 前置放大电路与功率放大电路相连放大倍数出现问题
  • ¥30 关于<main>标签页面跳转的问题
  • ¥80 部署运行web自动化项目
  • ¥15 腾讯云如何建立同一个项目中物模型之间的联系
  • ¥30 VMware 云桌面水印如何添加