是团团团子呀w 2023-02-17 13:04 采纳率: 66.7%
浏览 30
已结题

深度学习CNN中Lenet和优化器

有没有什么预训练的lenet模型
我目前代码如下

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 input image channel, 6 output channels, 5x5 square convolution
        self.conv1 = torch.nn.Conv2d(3, 6, kernel_size = 5, padding = 2)
        self.conv2 = torch.nn.Conv2d(6, 16, kernel_size = 5)
        self.conv3 = torch.nn.Conv2d(16, 32, kernel_size = 5)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.dropout = nn.Dropout(p=0.1)
        self.fc1 = torch.nn.Linear(32*13*13, 120)      
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 2)
  def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.reshape(x, (x.size()[0], -1))
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

优化使用的是SGD
不知道有没有什么简单清晰的方法提高正确率
目前是91左右,希望能上93

补充:训练集不大 只有5w张图

  • 写回答

3条回答 默认 最新

  • 事实证明 2023-02-17 13:17
    关注
    
    import torch
    import torchvision.models as models
    
    # 加载预训练的LeNet模型
    model = models.lenet(pretrained=True)
    
    # 将最后一层的输出改为2(因为你的任务是二分类)
    model.classifier[-1] = torch.nn.Linear(84, 2)
    
    # 将模型转移到GPU上
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # 你的优化器和损失函数等不变
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    

    另外,你可以尝试以下方法来提高模型的准确率:

    增加训练数据。如果你的训练数据集不够大,可以使用数据增强技术来生成更多的训练样本。

    调整学习率。你可以尝试减小学习率,以避免模型在训练时过早收敛或者过拟合。

    调整模型结构。可以尝试增加或减少卷积层、全连接层的数量,或者调整它们的大小。

    使用正则化技术。例如dropout和L2正则化等可以有效减少过拟合问题。

    尝试使用其他的优化器,例如Adam或Adagrad等,它们可能会比SGD更好地适应你的模型。

    进行模型融合。如果你有多个模型,可以将它们的预测结果进行融合,以提高整体的准确率。

    chatGPT回答的

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 3月9日
  • 已采纳回答 3月1日
  • 修改了问题 2月17日
  • 创建了问题 2月17日

悬赏问题

  • ¥100 set_link_state
  • ¥15 虚幻5 UE美术毛发渲染
  • ¥15 CVRP 图论 物流运输优化
  • ¥15 Tableau online 嵌入ppt失败
  • ¥100 支付宝网页转账系统不识别账号
  • ¥15 基于单片机的靶位控制系统
  • ¥15 真我手机蓝牙传输进度消息被关闭了,怎么打开?(关键词-消息通知)
  • ¥15 装 pytorch 的时候出了好多问题,遇到这种情况怎么处理?
  • ¥20 IOS游览器某宝手机网页版自动立即购买JavaScript脚本
  • ¥15 手机接入宽带网线,如何释放宽带全部速度