# If this answer helps, please remember to accept it. batch_size can be set to any value.
# -*- coding: UTF-8 -*-
"""
@项目名称:简单cnn网络_csv数据集_bug解决.py
@作 者:陆地起飞全靠浪
@创建日期:2022-08-31-09:20
https://ask.csdn.net/questions/7779923
https://blog.csdn.net/weixin_41944061?type=ask
"""
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datetime import datetime
class CsvDataset(Dataset):
    """Dataset backed by one feature CSV and one label CSV.

    The feature CSV holds one sample per column: its first column is skipped
    and every remaining column header is a sample name whose rows are that
    sample's values. The label CSV holds one sample per row: its first column
    is the sample name and its second column is the label. Sample names must
    be identical, and in the same order, in both files.

    Args:
        feature_path: Path of the feature CSV.
        label_path: Path of the label CSV.
        sample_shape: Shape every sample is reshaped to by ``__getitem__``.

    Raises:
        ValueError: If the sample names of the two files do not match.
    """

    def __init__(self, feature_path='data_trans.csv',
                 label_path='data_label_10.csv', sample_shape=(1, 12, 12)):
        super(CsvDataset, self).__init__()
        self.feature_path = feature_path
        self.label_path = label_path
        self.sample_shape = tuple(sample_shape)

        feature_df_ = pd.read_csv(self.feature_path)
        label_df_ = pd.read_csv(self.label_path)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        sample_names = feature_df_.columns.tolist()[1:]
        if sample_names != label_df_[label_df_.columns[0]].tolist():
            raise ValueError('feature name does not match label name')

        # Equal name lists also guarantee equal sample and label counts.
        self.feature = [feature_df_[name].tolist() for name in sample_names]
        self.label = label_df_[label_df_.columns[1]]
        self.length = len(self.feature)

    def __getitem__(self, index):
        """Return ``(x, y)``: the float32 sample tensor and its label."""
        x = torch.tensor(self.feature[index], dtype=torch.float32)
        x = x.reshape(self.sample_shape)
        # Positional lookup, independent of the Series' index labels.
        y = self.label.iloc[index]
        return x, y

    def __len__(self):
        """Return the number of samples."""
        return self.length
# Build the dataset from its default CSV files and wrap it in a loader that
# yields mini-batches of two samples in file order (no shuffling).
train_dataset = CsvDataset()
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=False)
class SimpleModel(nn.Module):
    """Small CNN classifier: two 1x1 convolutions followed by one linear layer.

    Args:
        input_size: ``(height, width)`` of the single-channel input images.
        num_classes: Number of logits produced per sample.
        debug: When true, ``forward`` prints the activation shape before and
            after flattening.
    """

    def __init__(self, input_size=(12, 12), num_classes=10, debug=False):
        super(SimpleModel, self).__init__()
        height, width = input_size
        self.debug = debug
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(1, 1))
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=(1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)  # (B, C, H, W) -> (B, C*H*W)
        # 1x1 convolutions keep H and W, so the flattened size is 5 * H * W.
        self.linear = nn.Linear(in_features=5 * height * width,
                                out_features=num_classes, bias=False)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        if self.debug:
            print("[before flatten] x.shape: {}".format(x.shape))
        x = self.flatten(x)
        if self.debug:
            print("[after flatten] x.shape: {}".format(x.shape))
        # No activation on the output: the logits stay unclamped so that a
        # loss such as CrossEntropyLoss can also use negative values.
        return self.linear(x)
# Network, loss and optimizer used by the training loop:
# SGD with momentum over all model parameters, cross-entropy as the loss.
model = SimpleModel()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
# Train for two epochs with one optimizer step per mini-batch.
# NOTE(review): the original indentation was lost; the per-epoch placement of
# the loss print below is assumed — confirm against the author's intent.
for epoch in range(2):
    bar_title = 'EPOCH:{}'.format(epoch)
    with tqdm(train_loader, desc=bar_title) as train_bar:
        for inputs, targets in train_bar:
            optimizer.zero_grad()
            predictions = model(inputs)
            loss = loss_fn(predictions, targets)
            loss.backward()
            optimizer.step()
    # Reports the loss of the last mini-batch processed in this epoch.
    print('epoch:{}, loss:{:.6f}'.format(epoch, loss))
# Save the trained weights under a date-stamped name such as
# model_2022_08_31.pth. The previous str(datetime.now()).split('') raised
# "ValueError: empty separator"; strftime builds the date stamp directly.
time = datetime.now().strftime('%Y_%m_%d')
torch.save(model.state_dict(), 'model_{}.pth'.format(time))