network爬虫 2023-12-07 22:33
Closed

Converting a binary classification task to multi-class classification

import paddle
import paddlehub as hub
import ast
import argparse
from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset


class MyDataset(TextClassificationDataset):
    # Directory where the dataset files are stored
    base_path = 'data/weibo_senti_100k'
    # Label list for the dataset; for multi-class classification the format is ['0', '1', '2', '3', ...]
    label_list = ['0', '1', '2', '3', '4', '5', '6']

    def __init__(self, tokenizer, max_seq_len: int = 128, mode: str = 'train'):
        if mode == 'train':
            data_file = 'train.tsv'
        elif mode == 'test':
            data_file = 'test.tsv'
        else:
            data_file = 'dev.tsv'
        super().__init__(
            base_path=self.base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_list=self.label_list,
            is_file_with_header=True)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
    parser.add_argument("--use_gpu", type=ast.literal_eval, default=True,
                        help="Whether use GPU for fine-tuning, input should be True or False")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
    parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
    parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
    parser.add_argument("--checkpoint_dir", type=str, default='./ernie_checkpoint',
                        help="Directory to model checkpoint")
    parser.add_argument("--save_interval", type=int, default=1, help="Save checkpoint every n epoch.")
    args = parser.parse_args()

    # Select the model, the task, and the number of classes
    model = hub.Module(name='ernie_tiny', task='seq-cls', num_classes=len(MyDataset.label_list))

    train_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='train')
    dev_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='dev')
    test_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='test')

    optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=model.parameters())
    trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=False)
    trainer.train(train_dataset, epochs=args.num_epoch, batch_size=args.batch_size, eval_dataset=dev_dataset,
                  save_interval=args.save_interval)
    # Evaluate the trained model on the test set
    trainer.evaluate(test_dataset, batch_size=args.batch_size)

The error message is:
AssertionError: Variable Shape not match, Variable [ linear_19.w_0_moment1_0 ] need tensor with shape [1024, 7] but load set tensor with shape [1024, 3]

15 answers

  • 专家-郭老师 (Rising Star Creator in Java) 2023-12-07 22:34

    This answer was produced with the help of AI and the answerer; use it as a reference. If anything is incorrect, please reply in the comments and I will follow up. Thanks.


    This error occurs because the weight shapes do not match when the saved model state is loaded. Specifically, the linear layer of the classification head (the tensor linear_19.w_0_moment1_0) expects a tensor of shape [1024, 7], but the tensor being loaded has shape [1024, 3], i.e. it was saved for a head with 3 output classes rather than the current 7. To fix this, check the weights being restored and make sure they match the linear layer in your current model; you can also try re-downloading a pretrained model whose structure matches your model.
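    The _moment1_0 suffix is an Adam optimizer accumulator, so the [1024, 3] tensor is most likely being restored from an old checkpoint in ./ernie_checkpoint that was saved while the head still had 3 classes; hub.Trainer will try to resume from checkpoint_dir if it finds a saved checkpoint there. A minimal sketch of one way to avoid restoring that stale state, assuming the stale checkpoint is the ./ernie_checkpoint directory from your earlier run (the alternative directory name below is only an example):

    import os
    import shutil

    stale_ckpt = './ernie_checkpoint'   # checkpoint_dir used by the earlier 3-class run
    if os.path.isdir(stale_ckpt):
        # Remove (or rename) the directory so the Trainer does not try to restore
        # a 3-class head and its optimizer state into the new 7-class model.
        shutil.rmtree(stale_ckpt)
    # Alternatively, start the script with a fresh directory, e.g.
    # --checkpoint_dir ./ernie_checkpoint_7cls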

    Modified code:

    import paddle
    import paddlehub as hub
    import ast
    import argparse
    from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset
    
    class MyDataset(TextClassificationDataset):
        # Directory where the dataset files are stored
        base_path = 'data/weibo_senti_100k'
        # Label list for the dataset; for multi-class classification the format is ['0', '1', '2', '3', ...]
        label_list = ['0', '1', '2', '3', '4', '5', '6']
    
        def __init__(self, tokenizer, max_seq_len: int = 128, mode: str = 'train'):
            if mode == 'train':
                data_file = 'train.tsv'
            elif mode == 'test':
                data_file = 'test.tsv'
            else:
                data_file = 'dev.tsv'
            super().__init__(
                base_path=self.base_path,
                tokenizer=tokenizer,
                max_seq_len=max_seq_len,
                mode=mode,
                data_file=data_file,
                label_list=self.label_list,
                is_file_with_header=True)
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser(__doc__)
        parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
        parser.add_argument("--use_gpu", type=ast.literal_eval, default=True,
                            help="Whether use GPU for fine-tuning, input should be True or False")
        parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
        parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
        parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
        parser.add_argument("--checkpoint_dir", type=str, default='./ernie_checkpoint',
                            help="Directory to model checkpoint")
        parser.add_argument("--save_interval", type=int, default=1, help="Save checkpoint every n epoch.")
        args = parser.parse_args()
    
        # Select the model, the task, and the number of classes
        model = hub.Module(name='ernie_tiny', task='seq-cls', num_classes=len(MyDataset.label_list))
    
        train_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='train')
        dev_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='dev')
        test_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=args.max_seq_len, mode='test')
    
        optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=model.parameters())
        trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
        trainer.train(train_dataset, epochs=args.num_epoch, batch_size=args.batch_size, eval_dataset=dev_dataset,
                      save_interval=args.save_interval)
        # Evaluate the trained model on the test set
        trainer.evaluate(test_dataset, batch_size=args.batch_size)
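
    As a quick sanity check, you can also confirm that every label appearing in the data files is covered by MyDataset.label_list before fine-tuning. A minimal sketch, assuming the TSVs have a header row and tab-separated columns with the label in the first column (adjust the column index if your files put the text first):

    import csv
    import os

    base_path = 'data/weibo_senti_100k'
    label_list = {'0', '1', '2', '3', '4', '5', '6'}

    for name in ('train.tsv', 'dev.tsv', 'test.tsv'):
        with open(os.path.join(base_path, name), encoding='utf-8') as f:
            rows = csv.reader(f, delimiter='\t')
            next(rows)                       # skip the header row
            labels = {row[0] for row in rows if row}  # label assumed in column 0
        # Any label outside label_list would break the 7-class setup
        print(name, 'labels found:', sorted(labels),
              'unexpected:', sorted(labels - label_list))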
    
    
    

