kaichengw1 2021-05-07 19:30 采纳率: 100%
浏览 26

lstm训练遇到瓶颈 测试集正确率在44

 

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,   
                            num_layers=num_layers,     
                            batch_first=True,
                            bidirectional = True)
        self.output_layer = nn.Linear(in_features=hidden_size*2, out_features=4)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        lstm_out, (h_n, h_c) = self.lstm(x, None)
        lstm_out = self.dropout(lstm_out)
        output = self.output_layer(lstm_out[:, -1, :])
        return output

lstm = LSTM()
lstm = lstm.float()
print(lstm)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()
for epoch in range(epoches):
        print("进行第{}个epoch".format(epoch))
        for step, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            
            batch_x = batch_x.view(-1,1,300)
            output = lstm(batch_x.float())
            
            loss = loss_function(output, batch_y.long())
            loss.backward()
            optimizer.step()
            
            if step % 50 == 0:
                test_x = dev.x.view(-1,1,300)
                test_output = lstm(test_x.float())
                pred_y = torch.max(test_output, dim=1)[1].data.numpy()
                
                accuracy = ((pred_y == dev.y.data.numpy()).astype(int).sum()) / float(dev.y.size(0))
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

 这是一个根据文本内容分析用户地点的lstm,正确率一直没法提高,是哪里出了问题,有没有可以改进的地方?

  • 写回答

0条回答 默认 最新

    报告相同问题?

    悬赏问题

    • ¥15 c程序不知道为什么得不到结果
    • ¥40 复杂的限制性的商函数处理
    • ¥15 程序不包含适用于入口点的静态Main方法
    • ¥15 素材场景中光线烘焙后灯光失效
    • ¥15 请教一下各位,为什么我这个没有实现模拟点击
    • ¥15 执行 virtuoso 命令后,界面没有,cadence 启动不起来
    • ¥50 comfyui下连接animatediff节点生成视频质量非常差的原因
    • ¥20 有关区间dp的问题求解
    • ¥15 多电路系统共用电源的串扰问题
    • ¥15 slam rangenet++配置