kaichengw1 2021-05-07 19:30 采纳率: 100%
浏览 26

lstm训练遇到瓶颈 测试集正确率在44

 

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,   
                            num_layers=num_layers,     
                            batch_first=True,
                            bidirectional = True)
        self.output_layer = nn.Linear(in_features=hidden_size*2, out_features=4)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        lstm_out, (h_n, h_c) = self.lstm(x, None)
        lstm_out = self.dropout(lstm_out)
        output = self.output_layer(lstm_out[:, -1, :])
        return output

lstm = LSTM()
lstm = lstm.float()
print(lstm)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()
for epoch in range(epoches):
        print("进行第{}个epoch".format(epoch))
        for step, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            
            batch_x = batch_x.view(-1,1,300)
            output = lstm(batch_x.float())
            
            loss = loss_function(output, batch_y.long())
            loss.backward()
            optimizer.step()
            
            if step % 50 == 0:
                test_x = dev.x.view(-1,1,300)
                test_output = lstm(test_x.float())
                pred_y = torch.max(test_output, dim=1)[1].data.numpy()
                
                accuracy = ((pred_y == dev.y.data.numpy()).astype(int).sum()) / float(dev.y.size(0))
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

 这是一个根据文本内容分析用户地点的lstm,正确率一直没法提高,是哪里出了问题,有没有可以改进的地方?

  • 写回答

0条回答 默认 最新

      报告相同问题?

      悬赏问题

      • ¥15 VB.NET 父窗体调取子窗体报错
      • ¥15 python海龟作图如何改代码使其最后画出来的是一个镜像翻转的图形
      • ¥15 我不明白为什么c#微软的官方api浏览器为什么不支持函数说明的检索,有支持检索函数说明的工具吗?
      • ¥15 ORBSLAM2框架跑ICL-NUIM数据集
      • ¥15 在我想检测ros是否成功安装时输入roscore出现以下
      • ¥30 老板让我做一个公司的投屏,实时显示日期,时间,安全生产的持续天数,完全没头绪啊
      • ¥15 Google Chrome 所有页面崩溃,三种解决方案都没有解决,我崩溃了
      • ¥20 使用uni-app发起网络请求,获取重定向302返回的cookie
      • ¥20 手机外部浏览器拉起微信小程序支付 (相关搜索:微信小程序)
      • ¥20 怎样通过一个网址找到其他同样模版的网址