讨厌代码! 2023-04-21 22:34 采纳率: 50%
浏览 14
已结题

Bert中文文本分类输出

基于Pytorch框架的Bert模型中文文本分类,输出结果是如下图所示的张量Tensor([151])等,我的分类类别是0、1,一直到6这样的数字标签,请问怎么从输出的张量结果看对应的预测类别标签?
这样输出的结果是正确的吗?
谢谢回答的大家!

img

部分模型引用代码如下:

with open('data_for_classify.txt','r',encoding='utf-8') as f:
    for line in f.readlines():
        new_sentences = line.strip()
        input_ids, seq_len, attention_mask = preprocess(new_sentences)
        preprocessed_sentences = preprocess(new_sentences)
        logits = bert_model.forward(preprocessed_sentences[0], attention_mask=preprocessed_sentences[2])[1]
        #preprocessed_sentences = preprocess(new_sentences)
        #logits = bert_model.forward()
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_classes = torch.argmax(probabilities, dim=-1)
        print(predicted_classes)
  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-04-22 08:12
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:
    • 这个问题的回答你可以参考下: https://ask.csdn.net/questions/7540840
    • 这篇博客也不错, 你可以看下由浅入深,带你搞懂 Pytorch 中的张量 tensor 是什么
    • 除此之外, 这篇博客: Pytorch中使用Bert预训练模型,并给定句子得到对应的向量中的 修改输入形式 部分也许能够解决你的问题, 你可以仔细阅读以下内容或者直接跳转源博客中阅读:

      Bert的输入要有三个向量:(1)input_ids (2)token_type_ids (3)attention_mask。这三个向量可以通过一行代码获得:

      sentenceA = '等潮水退去,就知道谁没穿裤子'
      text_dict = tokenizer.encode_plus(sentenceA, add_special_tokens=True, return_attention_mask=True)
      
    • 您还可以看一下 龙良曲老师的深度学习与PyTorch入门实战教程课程中的 Tensor运算小节, 巩固相关知识点
    • 以下回答来自chatgpt:

      针对问题中的张量输出结果,可以使用PyTorch中的item()方法获取张量中的值,并使用近似函数进行取整操作得到最终的预测类别标签。

      具体实现步骤如下:

      1. 将输出的张量中的值取出,用item()方法将其转换为标量。

      2. 对取出的标量进行四舍五入的操作,使用Python内置的round()函数实现。

      3. 将四舍五入后的结果转换为整数类型,使用Python内置的int()函数实现。

      4. 将整数类型的结果作为最终的预测类别标签返回。

      代码实现如下:

      # 模型预测
      with torch.no_grad():
          output = model(input_ids, attention_mask, token_type_ids)
          logits = output[0]
          preds = torch.argmax(logits, dim=1)
          pred_label = int(round(preds.item()))
          print("预测类别标签为:", pred_label)
      

      其中,preds.item()用于获取张量中的值,int()函数用于将四舍五入后的标量转换为整数类型,round()函数用于实现四舍五入操作。


    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 3月29日
  • 创建了问题 4月21日