kaichengw1 2021-05-07 19:30 采纳率: 100%
浏览 26

lstm训练遇到瓶颈 测试集正确率在44

 

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,   
                            num_layers=num_layers,     
                            batch_first=True,
                            bidirectional = True)
        self.output_layer = nn.Linear(in_features=hidden_size*2, out_features=4)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        lstm_out, (h_n, h_c) = self.lstm(x, None)
        lstm_out = self.dropout(lstm_out)
        output = self.output_layer(lstm_out[:, -1, :])
        return output

lstm = LSTM()
lstm = lstm.float()
print(lstm)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()
for epoch in range(epoches):
        print("进行第{}个epoch".format(epoch))
        for step, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            
            batch_x = batch_x.view(-1,1,300)
            output = lstm(batch_x.float())
            
            loss = loss_function(output, batch_y.long())
            loss.backward()
            optimizer.step()
            
            if step % 50 == 0:
                test_x = dev.x.view(-1,1,300)
                test_output = lstm(test_x.float())
                pred_y = torch.max(test_output, dim=1)[1].data.numpy()
                
                accuracy = ((pred_y == dev.y.data.numpy()).astype(int).sum()) / float(dev.y.size(0))
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

 这是一个根据文本内容分析用户地点的lstm,正确率一直没法提高,是哪里出了问题,有没有可以改进的地方?

  • 写回答

0条回答 默认 最新

      报告相同问题?

      相关推荐 更多相似问题

      悬赏问题

      • ¥15 vscode系统开发
      • ¥15 看看怎么编程?运用函数
      • ¥15 求解,老毛子skokina是什么东西,是否存在
      • ¥15 git回滚后怎么再恢复
      • ¥15 轴承故障诊断,CDAE之后加傅里叶变换FFT,然后输入到BiLSTM中去,这个傅里叶变换该怎么加。
      • ¥15 哪位硬件专家帮助分析下OPFC无输出 原因,电压都已标注出来,芯片是FAN4800A 。
      • ¥15 启动navicat时报10061错
      • ¥20 关于#pcb工艺#的问题:只需要设计图和设计原理
      • ¥15 关于#gstreamer webrtcbin#的问题,如何解决?
      • ¥15 怎么用c语言函数编写宿舍财务管理系统?