pytorch 入门 图像识别问题
pytorch入门抄了一下网上的代码,想让他识别手写数字。不知道为什么训练效果总是很差,有没有大佬帮忙解释一下。
代码以及手写图像如下:(结果从来没有得到过7,一般是1 4 5 不知道为什么
import torch
import torchvision
from PIL import Image
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
# prepare dataset
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
image = torchvision.transforms.functional.to_tensor(Image.open('D:/7.png'))
image = image.view(1,1,28,28)
# design model using class
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
self.pooling = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(320, 10)
def forward(self, x):
# flatten data from (n,1,28,28) to (n, 784)
batch_size = x.size(0)
x = F.relu(self.pooling(self.conv1(x)))
x = F.relu(self.pooling(self.conv2(x)))
x = x.view(batch_size, -1) # -1 此处自动算出的是320
x = self.fc(x)
return x
model = Net()
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# training cycle forward, backward, update
def train(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader, 0):
inputs, target = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
running_loss = 0.0
def test():
with torch.no_grad():
outputs = model(image)
_, predicted = torch.max(outputs.data, dim=1) # dim = 1 列是第0个维度,行是第1个维度
print((outputs))
print(predicted.data)
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()
- 点赞
- 写回答
- 关注问题
- 收藏
- 复制链接分享
- 邀请回答
为你推荐
- 跑pytorch时出现Parallel dataloader disabled
- Pytorch使用tensorboard报错?
- 深度学习
- python
- tensorflow
- 1个回答
- pytorch这是什么错误?
- 关于使用pytorch构建GRU
- python
- 8个回答
- ML 工具 keras, tensorflow2.0,pytorch
- python
- 2个回答
- pytorch报错CUDA error: invalid device function
- pytorch point源代码出现KeyError: Caught KeyError in DataLoader worker process 0.如何解决?
- pytorch的MNIST代码中loss输出的疑问
- pytorch空权重运行,按理说是随机数,为何每次运行结果相同?
- pytorch利用卷积神经网络实现验证码识别,但是在写测试集的准确率函数时遇到问题
- 用anaconda自己配置了一个pytorch环境,那么我可以把base环境下的依赖库的文件夹复制到pytorch环境下,省去自己安装的步骤么????
- 关于 pytorch中Tensor数据类型的使用问题
- pytorch sum()结果为什么不正确
- pytorch cuda版运行出错 invalid start byte
- pytorch训练LSTM模型的代码疑问
- pytorch图像数据集怎么进行交叉验证
- 神经网络
- 1个回答
- 找DAN,DDC,JAN,RTN,simNet,ResNet-50等模型的pytorch框架代码。能找几个是几个。
- pytorch加载model发现key的值有差异,能不能修改
- pytorch RuntimeError: already started
- pytorch安装后不能import