筱萱儿 2022-07-18 14:56 · Acceptance rate: 0%
992 views
Closed

Training and validation accuracy of a neural network model never change

The problem:
While training a BiLSTM + multi-head attention model, the accuracy on both the training set and the validation set stays exactly the same from epoch to epoch.
Relevant code:

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from sklearn.metrics import f1_score, recall_score


class BiLSTM_Multi_head_Attention(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 pretrained_weight, update_w2v, max_sen_len, dropout=0.2):
        super().__init__()
        # Load the pretrained word vectors, e.g. shape [5798, 50]
        self.embedding = nn.Embedding.from_pretrained(pretrained_weight)
        # Whether the pretrained embeddings are updated during training
        self.embedding.weight.requires_grad = update_w2v
        # BiLSTM
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )
        # Multi-head attention over the BiLSTM outputs.
        # Note: nn.MultiheadAttention defaults to batch_first=False,
        # i.e. it expects [seq len, batch size, embed dim] inputs.
        self.mha = nn.MultiheadAttention(2 * hidden_dim, num_heads=8)
        # Flatten into [batch size, 2 * hidden_dim * max_sen_len]
        self.flatten = nn.Flatten()
        # Fully connected classifier (input doubled because the LSTM is bidirectional)
        self.fc1 = nn.Linear(max_sen_len * 2 * hidden_dim, 1024)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(1024, 256)
        self.fc5 = nn.Linear(256, output_dim)

    def forward(self, text):
        # Embed the input token ids
        embedded = self.embedding(text)  # [batch size, sent len, emb dim]
        # BiLSTM output (2 * hid dim because of the two directions)
        lstm_output, (ht, cell) = self.lstm(embedded)  # [batch size, sent len, 2 * hid dim]
        # Self-attention over the LSTM outputs
        attn_output, attn_output_weights = self.mha(lstm_output, lstm_output, lstm_output)
        # Flatten to [batch size, sent len * 2 * hid dim]
        x = self.flatten(attn_output)
        # Classifier
        # Layer 1
        x = self.fc1(x)
        x = F.softmax(x, dim=1)
        # Dropout
        x = self.dropout(x)
        # Layer 2
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        # Output layer
        output = self.fc5(x)
        return output
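
For anyone hitting the same symptom, two things stand out in this code. First, F.softmax is used as the activation between the fully connected layers, and nn.CrossEntropyLoss already applies log-softmax internally, so the network should output raw logits; softmax as a hidden activation squashes each layer's output onto a probability simplex and can stall learning almost completely. Second, the LSTM is built with batch_first=True, but nn.MultiheadAttention defaults to batch_first=False (it expects [seq len, batch, embed dim]), so the attention mixes up the batch and sequence dimensions; recent PyTorch versions accept batch_first=True in its constructor. Below is a minimal sketch of the forward pass with ReLU activations and raw logits instead; this is a suggested fix, not the author's original code:

    def forward(self, text):
        embedded = self.embedding(text)
        lstm_output, (ht, cell) = self.lstm(embedded)
        # assumes self.mha was built with batch_first=True to match the LSTM layout
        attn_output, _ = self.mha(lstm_output, lstm_output, lstm_output)
        x = self.flatten(attn_output)
        # ReLU instead of softmax between layers; softmax saturates here
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself
        return self.fc5(x)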

def train(train_dataloader, model, device, epoches, lr):
    # Put the model in training mode
    model.train()
    # Move the model to the GPU
    model = model.to(device)
    print(model)
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Cross-entropy loss (expects raw logits)
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.85
    # One epoch is one full pass over the training data
    for epoch in range(epoches):
        train_loss = 0.0
        correct = 0
        total = 0
        # Wrapping the dataloader in tqdm gives a per-batch progress bar
        train_dataloader = tqdm.tqdm(train_dataloader)
        # Iterate over the batches
        for i, data_ in enumerate(train_dataloader):
            # Zero the gradients
            optimizer.zero_grad()
            input_, target = data_[1], data_[2]
            # Cast to integer tensors
            input_ = input_.type(torch.LongTensor)
            target = target.type(torch.LongTensor)
            # Move the batch to the GPU
            input_ = input_.to(device)
            target = target.to(device)
            # Forward pass
            output = model(input_)
            # Drop the extra dimension: [batch size, 1] -> [batch size]
            target = target.squeeze(1)
            # Loss
            loss = criterion(output, target)
            # Backward pass
            loss.backward()
            # Parameter update
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            # Count samples in this batch
            total += target.size(0)
            # Count correct predictions
            correct += (predicted == target).sum().item()
            acc = 100 * correct / total
            # Per-batch F1 and recall (micro-averaged recall equals accuracy here)
            F1 = f1_score(target.cpu(), predicted.cpu(), average='weighted')
            Recall = recall_score(target.cpu(), predicted.cpu(), average='micro')
            postfix = 'train_loss: {:.5f}, train_acc: {:.3f}%, F1: {:.3f}%, Recall: {:.3f}%'.format(
                train_loss / (i + 1), 100 * correct / total, 100 * F1, 100 * Recall)
            # Show the running metrics on the tqdm progress bar
            train_dataloader.set_postfix(log=postfix)

        # Validation accuracy (val_dataloader comes from the enclosing scope)
        acc = val_accuary(model, val_dataloader, device, criterion)
        # Save the model whenever the validation accuracy improves
        if acc > best_acc:
            best_acc = acc
            if not os.path.exists(Config.model_state_dict_path):
                os.mkdir(Config.model_state_dict_path)
            save_path = 'HA/{}_epoch_{}.pkl'.format("sen_model", epoch)
            print(os.path.join(Config.model_state_dict_path, save_path))
            torch.save(model, os.path.join(Config.model_state_dict_path, save_path))
        # Back to training mode after validation
        model.train()
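
val_accuary (spelling as in the code above) is called but its definition isn't shown. For completeness, here is a minimal sketch of what such a helper could look like, assuming the validation batches use the same (index, input, target) layout as the training loop; this is a reconstruction, not the author's actual function:

def val_accuary(model, val_dataloader, device, criterion):
    # Evaluation mode disables dropout
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for data_ in val_dataloader:
            input_ = data_[1].type(torch.LongTensor).to(device)
            target = data_[2].type(torch.LongTensor).to(device).squeeze(1)
            output = model(input_)
            val_loss += criterion(output, target).item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    acc = correct / total
    print('val_loss: {:.5f}, val_acc: {:.3f}%'.format(val_loss / len(val_dataloader), 100 * acc))
    # Return a fraction, consistent with best_acc = 0.85 in train()
    return acc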

Run results:

(screenshot of the training log omitted; per the problem description, the accuracy stays at the same value throughout training)


1 answer

  • 筱萱儿 2022-07-18 16:26

    Still looking for an answer.


Question timeline

  • Closed by the system on July 26
  • Question created on July 18
