Chuck Joe 2023-06-15 11:52 采纳率: 100%
浏览 31
已结题

怎样用transformer测试

Transformer预测
我现在跑通了Transformer的源码,
数据集是Multi30k德语-英语翻译,
训练结束后生成了几个.pt文件,
请问一下,我现在想进行测试,就是给它一个英语,怎样才能翻译成德语

  • 写回答

2条回答 默认 最新

  • 玥轩_521 2023-06-15 15:06
    关注

    恭喜你已经成功训练 Transformer 模型,下面是如何进行翻译的步骤:

    1.准备测试文本数据;
    2.加载预训练的模型,并设置为评估模式;
    3.对测试文本数据进行预处理,包括分词、生成输入张量等;
    4.将预处理后的输入张量输入给模型,利用模型对其进行翻译;
    5.将输出结果进行后处理,包括解码、反标记化等;
    6.输出翻译结果。

    下面是一个简单的 Python 代码示例:

    import torch
    from torchtext.data.metrics import bleu_score
    import spacy
    
    # 加载预训练的模型
    model_path = "path/to/model.pt"
    model = torch.load(model_path)
    
    # 设置为评估模式
    model.eval()
    
    # 预处理输入文本
    input_sentence = "This is a test sentence."
    nlp = spacy.load("en_core_web_sm")
    tokenized = [tok.text for tok in nlp.tokenizer(input_sentence)]
    SRC = Field(tokenize = "spacy",init_token = '<sos>',eos_token = '<eos>',lower = True,batch_first = True,)
    SRC.vocab = SRC.vocab.load('./Multi30k-vocab/transformer_spacy_multi30k_vocabulary.pt')
    SRC = SRC.cuda()
    input_tensor = SRC.process([tokenized]).cuda()
    
    # 计算翻译结果
    with torch.no_grad():
        outputs = model(input_tensor)
        outputs = outputs.argmax(dim=-1)
        outputs = outputs.cpu().numpy().tolist()[0]
    
    # 后处理翻译结果
    TRG = Field(tokenize = "spacy",init_token = '<sos>',eos_token = '<eos>',lower = True,batch_first = True,)
    TRG.vocab = TRG.vocab.load('./Multi30k-vocab/transformer_spacy_multi30k_vocabulary.pt')
    TRG = TRG.cuda()
    translated_sentence = ' '.join([TRG.vocab.itos[idx] for idx in outputs if idx not in [TRG.vocab.stoi[TRG.init_token], TRG.vocab.stoi[TRG.eos_token]]])
    
    # 输出翻译结果
    print(translated_sentence)
    

    请注意,你需要自定义 DataLoader 和 Field 来读取和预处理测试数据。此外,该示例中需要安装 torchtext 和 spacy 库。希望这可以帮助到你。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 6月23日
  • 已采纳回答 6月15日
  • 创建了问题 6月15日

悬赏问题

  • ¥15 Tree 树形控件实现单选功能,可以使用element也可以手写一个,实现全选为全选状态
  • ¥60 寻抓云闪付tn组成网页付款链接
  • ¥20 如何通过sentry收集上传Android ndk的崩溃?
  • ¥30 有关数值生成的稳健性探讨
  • ¥20 关于C/C++图形界面设计问题
  • ¥15 QT+Gstreamer框架开发视频采集,无法将waylandsink视频绑定qt窗口
  • ¥15 vs2010开发 WFP(windows filtering platform)异常
  • ¥30 8*8*25的矩阵和1*8*25的矩阵相乘
  • ¥15 Ubuntu20.04主机有两个网口,如何配置将其中一个网口用来接入外网,另一个网口用来给其他设备上网
  • ¥15 ml307r-dl如何实现录音功能