network爬虫 2023-12-07 22:33 采纳率: 0%
浏览 14
已结题

二分类改为多分类问题

import paddle
import paddlehub as hub
import ast
import argparse
from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset


class MyDataset(TextClassificationDataset):
    # 数据集存放目录
    base_path = 'data/weibo_senti_100k'
    # 数据集的标签列表,多分类标签格式为['0', '1', '2', '3',...]
    label_list = ['0', '1', '2','3','4','5','6']

def __init__(self, tokenizer, max_seq_len: int = 128, mode: str = 'train'):
    if mode == 'train':
        data_file = 'train.tsv'
    elif mode == 'test':
        data_file = 'test.tsv'
    else:
        data_file = 'dev.tsv'
    super().__init__(
        base_path=self.base_path,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        mode=mode,
        data_file=data_file,
        label_list=self.label_list,
        is_file_with_header=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
    parser.add_argument("--use_gpu", type=ast.literal_eval, default=True,
                        help="Whether use GPU for fine-tuning, input should be True or False")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
    parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
    parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
    parser.add_argument("--checkpoint_dir", type=str, default='./ernie_checkpoint',
                        help="Directory to model checkpoint")
    parser.add_argument("--save_interval", type=int, default=1, help="Save checkpoint every n epoch.")
    args = parser.parse_args()

    # 选择模型、任务和类别数
    model = hub.Module(name='ernie_tiny', task='seq-cls', num_classes=len(MyDataset.label_list))

    train_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='train')
    dev_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='dev')
    test_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='test')

    optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=model.parameters())
    trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=False)
    trainer.train(train_dataset, epochs=args.num_epoch, batch_size=args.batch_size, eval_dataset=dev_dataset,
                  save_interval=args.save_interval)
    # 在测试集上评估当前训练模型
    trainer.evaluate(test_dataset, batch_size=args.batch_size)

出错提示为
AssertionError: Variable Shape not match, Variable [ linear_19.w_0_moment1_0 ] need tensor with shape [1024, 7] but load set tensor with shape [1024, 3]

  • 写回答

15条回答 默认 最新

  • 郭老师的小迷弟雅思莫了 Java领域新星创作者 2023-12-07 22:34
    关注

    由人工智能和答主提供,可以参考如下,如果回答的不正确,及时评论区回复,我追加回答,谢谢。


    这个错误提示是因为在加载预训练模型时,权重矩阵的形状不匹配。具体来说,线性层(linear_19.w_0_moment1_0)需要一个形状为[1024, 7]的张量,但加载的张量形状为[1024, 3]。为了解决这个问题,你需要检查预训练模型的权重矩阵,并确保它与你的模型中的线性层相匹配。你可以尝试重新下载一个与你的模型结构相匹配的预训练模型。

    修改后代码:

    import paddle
    import paddlehub as hub
    import ast
    import argparse
    from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset
    
    class MyDataset(TextClassificationDataset):
        # 数据集存放目录
        base_path = 'data/weibo_senti_100k'
        # 数据集的标签列表,多分类标签格式为['0', '1', '2', '3',...]
        label_list = ['0', '1', '2','3','4','5','6']
    
    def __init__(self, tokenizer, max_seq_len: int = 128, mode: str = 'train'):
        if mode == 'train':
            data_file = 'train.tsv'
        elif mode == 'test':
            data_file = 'test.tsv'
        else:
            data_file = 'dev.tsv'
        super().__init__(
            base_path=self.base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_list=self.label_list,
            is_file_with_header=True)
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser(__doc__)
        parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
        parser.add_argument("--use_gpu", type=ast.literal_eval, default=True,
                            help="Whether use GPU for fine-tuning, input should be True or False")
        parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
        parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
        parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
        parser.add_argument("--checkpoint_dir", type=str, default='./ernie_checkpoint',
                            help="Directory to model checkpoint")
        parser.add_argument("--save_interval", type=int, default=1, help="Save checkpoint every n epoch.")
        args = parser.parse_args()
    
        # 选择模型、任务和类别数
        model = hub.Module(name='ernie_tiny', task='seq-cls', num_classes=len(MyDataset.label_list))
    
        train_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='train')
        dev_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='dev')
        test_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='test')
    
        optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=model.parameters())
        trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=False)
        trainer.train(train_dataset, epochs=args.num_epoch, batch_size=args.batch_size, eval_dataset=dev_dataset,
                      save_interval=args.save_interval)
        # 在测试集上评估当前训练模型
        trainer.evaluate(test_dataset, batch_size=args.batch_size)
    
    
    
    评论 编辑记录

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 12月14日
  • 创建了问题 12月7日

悬赏问题

  • ¥20 一个python博客项目的相关图例
  • ¥15 轮廓提取也提取不到,有没有别的方法,如何解决?
  • ¥50 Js和c++如何将含有图片的excel文件上传到后台服务器
  • ¥15 光电神经网络,FPGA
  • ¥20 通过防火墙出入站阻止游戏程序联网失效
  • ¥15 鼠标是可以在QT界面上移动的,但是热拔插鼠标无法移动了同时板子上是没问题的,如何解决?
  • ¥15 iframe嵌套显示问题
  • ¥20 【UE4】别人打包好的ue4游戏我该如何在自己的ue4引擎上运行
  • ¥15 power bi发布的链接地址打不开
  • ¥15 pip list列表中有库,但是编译时就显示缺少库