kaichengw1 2021-05-07 19:30 采纳率: 100%
浏览 26

lstm训练遇到瓶颈 测试集正确率在44

 

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,   
                            num_layers=num_layers,     
                            batch_first=True,
                            bidirectional = True)
        self.output_layer = nn.Linear(in_features=hidden_size*2, out_features=4)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        lstm_out, (h_n, h_c) = self.lstm(x, None)
        lstm_out = self.dropout(lstm_out)
        output = self.output_layer(lstm_out[:, -1, :])
        return output

lstm = LSTM()
lstm = lstm.float()
print(lstm)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()
for epoch in range(epoches):
        print("进行第{}个epoch".format(epoch))
        for step, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            
            batch_x = batch_x.view(-1,1,300)
            output = lstm(batch_x.float())
            
            loss = loss_function(output, batch_y.long())
            loss.backward()
            optimizer.step()
            
            if step % 50 == 0:
                test_x = dev.x.view(-1,1,300)
                test_output = lstm(test_x.float())
                pred_y = torch.max(test_output, dim=1)[1].data.numpy()
                
                accuracy = ((pred_y == dev.y.data.numpy()).astype(int).sum()) / float(dev.y.size(0))
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

 这是一个根据文本内容分析用户地点的lstm,正确率一直没法提高,是哪里出了问题,有没有可以改进的地方?

  • 写回答

0条回答 默认 最新

    报告相同问题?

    悬赏问题

    • ¥15 聚类分析或者python进行数据分析
    • ¥15 逻辑谓词和消解原理的运用
    • ¥15 三菱伺服电机按启动按钮有使能但不动作
    • ¥15 js,页面2返回页面1时定位进入的设备
    • ¥50 导入文件到网吧的电脑并且在重启之后不会被恢复
    • ¥15 (希望可以解决问题)ma和mb文件无法正常打开,打开后是空白,但是有正常内存占用,但可以在打开Maya应用程序后打开场景ma和mb格式。
    • ¥20 ML307A在使用AT命令连接EMQX平台的MQTT时被拒绝
    • ¥20 腾讯企业邮箱邮件可以恢复么
    • ¥15 有人知道怎么将自己的迁移策略布到edgecloudsim上使用吗?
    • ¥15 错误 LNK2001 无法解析的外部符号