m0_62735436 2023-06-14 17:04 采纳率: 0%
浏览 8

flask与深度学习

使用flask将LeNet5模型的运行结果和测试结果上传web

  • 写回答

2条回答 默认 最新

  • 断水流大撕兄 新星创作者: 操作系统技术领域 2023-06-14 18:00
    关注

    我这里有一个示例,你可以参考借鉴思路

    from flask import Flask, request, jsonify
    import pandas as pd
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets
    import torch.optim as optim
    
    app = Flask(__name__)
    
    # 定义LeNet5模型
    class LeNet5(nn.Module):
        def __init__(self):
            super(LeNet5, self).__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(256, 120)
            self.fc2 = nn.Linear(120, 84) 
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) 
            x = F.max_pool2d(F.relu(self.conv2(x)), 2) 
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    # 训练好的模型参数    
    model_params = torch.load('model.pth') 
    model = LeNet5()
    model.load_state_dict(model_params)
    
    # 测试数据集
    test_set = datasets.MNIST('mnist_data', train=False, download=True, 
                               transform=transforms.Compose([
                                   transforms.ToTensor(), 
                                   transforms.Normalize((0.1307,), (0.3081,))
                               ]))
    test_loader = DataLoader(test_set, batch_size=128, shuffle=True)
    
    # 模型预测结果
    preds = []
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch 
            preds.append(model(images).max(1)[1].numpy())
    preds = np.concatenate(preds)
    
    # API
    @app.route('/predict', methods=['POST'])
    def predict():
        img = request.files['img']
        img = torch.from_numpy(preprocessing(img)).unsqueeze(0)
        with torch.no_grad():
            pred = model(img).max(1)[1].item()
        return jsonify(pred)
    
    @app.route('/accuracy', methods=['GET'])
    def accuracy():
        acc = (preds == test_set.targets.numpy()).mean()
        return jsonify(acc)
    
    if __name__ == '__main__':
        app.run(debug=True) 
    
    评论

报告相同问题?

问题事件

  • 创建了问题 6月14日