每天都在头秃 2023-05-12 14:04 采纳率: 96.7%
浏览 23
已结题

Python图像分类遇到的问题

写 Python 图像分类程序时遇到了问题，本人 Python 学得不到位，可以帮忙看看吗？

from __future__ import print_function
import argparse

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from utils.config_utils import read_args, load_config, Dict2Object


class Net(nn.Module):
    """Small CNN classifier producing 10 log-probabilities per input image.

    Layer sizes are fixed for single-channel 28x28 inputs. Two unpadded 3x3
    convolutions take 28 -> 26 -> 24, and a 2x2 max-pool gives 12. That yields
    64 * 12 * 12 = 9216 features for ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels, 3x3 kernels, stride 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Regularisation after pooling (p=0.25) and after the hidden FC layer (p=0.5).
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head: 9216 flattened features -> 128 hidden units -> 10 classes.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-softmax class scores of shape (batch, 10) for the batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """
    Train the model for one epoch and return the training accuracy and loss.

    :param args: input arguments (not read inside this function; kept for API compatibility)
    :param model: neural network model; its output is fed to ``F.nll_loss``, so it
        should return log-probabilities
    :param device: the device where the model is stored
    :param train_loader: data loader yielding ``(data, target)`` batches
    :param optimizer: optimizer that updates ``model``'s parameters
    :param epoch: current epoch (not read inside this function)
    :return: tuple ``(accuracy_percent, mean_loss_per_sample)`` over ``train_loader.dataset``
    """
    model.train()
    correct = 0
    running_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # argmax WITHOUT keepdim: a (B, 1) prediction compared against a (B,)
        # target would broadcast to (B, B) and count far too many matches.
        predicted = output.argmax(dim=1)
        correct += (predicted == target).sum().item()
        # nll_loss averages over the batch; scale back to a batch sum so that
        # dividing by the dataset size below yields a true per-sample mean.
        running_loss += loss.item() * data.size(0)

    num_samples = len(train_loader.dataset)
    training_acc = 100. * correct / num_samples
    training_loss = running_loss / num_samples
    return training_acc, training_loss


def test(model, device, test_loader):
    """
    Evaluate the model and return the testing accuracy and loss.

    :param model: neural network model; its output is fed to ``F.nll_loss``, so it
        should return log-probabilities
    :param device: the device where the model is stored
    :param test_loader: data loader yielding ``(data, target)`` batches
    :return: tuple ``(accuracy_percent, mean_loss_per_sample)`` over ``test_loader.dataset``
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            predicted = outputs.argmax(dim=1)
            # .item() must be *called*: adding the bound method itself to a
            # number raises TypeError.
            total_loss += F.nll_loss(outputs, target, reduction='sum').item()
            correct += (predicted == target).sum().item()

    num_samples = len(test_loader.dataset)
    testing_acc = 100. * correct / num_samples
    testing_loss = total_loss / num_samples
    return testing_acc, testing_loss


# pytest collects any module-level function whose name starts with "test" and
# tries to inject its parameters as fixtures ("fixture 'model' not found").
# This is an evaluation helper, not a test case, so opt it out of collection.
test.__test__ = False


def plot(epoches, performance, title='model performance', ylabel='performance'):
    """
    Plot one recorded performance curve against the epochs and show it.

    :param epoches: recorded epoches (x-axis values)
    :param performance: recorded performance values, one per epoch (y-axis values)
    :param title: figure title; defaults to the original generic title
    :param ylabel: y-axis label; defaults to the original generic label
    :return: None
    """
    # Start a fresh figure so consecutive calls never draw onto the same axes.
    plt.figure()
    plt.title(title)
    plt.xlabel('epoches')
    plt.ylabel(ylabel)
    plt.plot(epoches, performance)
    plt.show()


def run(config):
    """
    Train and evaluate the MNIST CNN, record per-epoch metrics, plot them and
    optionally save the trained weights.

    :param config: settings object; this function reads no_cuda, no_mps, seed,
        batch_size, test_batch_size, lr, gamma, epochs and save_model from it
    :return: None
    """
    use_cuda = not config.no_cuda and torch.cuda.is_available()
    # torch.backends.mps does not exist on older PyTorch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = not config.no_mps and mps_backend is not None and mps_backend.is_available()

    torch.manual_seed(config.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Dedicated generator seeded from config.seed, so the training DataLoader's
    # shuffling order is reproducible independently of other RNG consumers.
    loader_generator = torch.Generator()
    loader_generator.manual_seed(config.seed)

    train_kwargs = {'batch_size': config.batch_size, 'shuffle': True,
                    'generator': loader_generator}
    # Accuracy and summed loss do not depend on sample order, so the
    # evaluation set is not shuffled.
    test_kwargs = {'batch_size': config.test_batch_size, 'shuffle': False}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True, }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # download data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST('./data', train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=config.lr)

    # Per-epoch performance records used for plotting.
    epoches = []
    training_accuracies = []
    training_loss = []
    testing_accuracies = []
    testing_loss = []

    scheduler = StepLR(optimizer, step_size=1, gamma=config.gamma)
    for epoch in range(1, config.epochs + 1):
        train_acc, train_loss = train(config, model, device, train_loader, optimizer, epoch)
        test_acc, test_loss = test(model, device, test_loader)
        scheduler.step()

        # Without these appends the record lists stay empty and every plot is blank.
        epoches.append(epoch)
        training_accuracies.append(train_acc)
        training_loss.append(train_loss)
        testing_accuracies.append(test_acc)
        testing_loss.append(test_loss)
        print('Epoch {}: train acc {:.2f}%, train loss {:.4f}, '
              'test acc {:.2f}%, test loss {:.4f}'.format(
                  epoch, train_acc, train_loss, test_acc, test_loss))

    # Training performance.
    plot(epoches, training_accuracies)
    plot(epoches, training_loss)

    # Testing performance.
    plot(epoches, testing_accuracies)
    plot(epoches, testing_loss)

    if config.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")






if __name__ == '__main__':
    # Parse the command-line arguments and build the training configuration from them.
    cli_args = read_args()
    config = load_config(cli_args)

    # Train the model and record the results.
    run(config)

   


问题报错

```text
test setup failed
file C:\Users\admin\Desktop\CSC1004-python-project-main\main.py, line 71
  def test(model, device, test_loader):
E       fixture 'model' not found
>       available fixtures: anyio_backend, anyio_backend_name, anyio_backend_options, cache, capfd, capfdbinary, caplog, capsys, capsysbinary, doctest_namespace, monkeypatch, pytestconfig, record_property, record_testsuite_property, record_xml_attribute, recwarn, tmp_path, tmp_path_factory, tmpdir, tmpdir_factory
>       use 'pytest --fixtures [testpath]' for help on them.


```

  • 写回答

3条回答 默认 最新

  • 赵4老师 2023-05-12 14:18
    关注

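    报错的意思是：pytest 找不到名为 model 的 fixture，model 不在 “available fixtures: …” 列出的可用值之中。原因是 pytest 会把 main.py 中名为 test 的函数当作测试用例收集，并试图把它的参数（model、device、test_loader）当作 fixture 注入。可以直接用 python main.py 运行脚本，或在该函数定义后加上 test.__test__ = False，让 pytest 不再收集它。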
    python调试三板斧 https://ask.csdn.net/questions/7908322/54130133

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 5月20日
  • 已采纳回答 5月12日
  • 修改了问题 5月12日
  • 创建了问题 5月12日

悬赏问题

  • ¥15 pdfplumber提起文本内容如何过滤水印内容
  • ¥15 kingbase容器启动失败,sudo:/bin/sudo must be owned by uid 0 and have the setuid bit set
  • ¥20 黑神话悟空调用机械硬盘导致卡顿
  • ¥15 vue中使用antv-x6
  • ¥15 qt编译失败,环境变量已加,但还是报错
  • ¥15 用顺序表实现学生信息的管理: 学生类(Student):学号(no)、姓名(name)、分数(double),对学生信息进行存取、插入、删除、查找操作,给出代码和运行结果图。
  • ¥15 minted包显示缩进符的问题
  • ¥15 根据图片连接电路51单片机proteus汇编语言仿真4位数码管静态显示
  • ¥15 .net项目集成阿里云智能语音SDK
  • ¥15 c#关于WPS中网格线隐藏的属性