不明真相的滑稽天使 2023-10-08 10:34 采纳率: 0%
浏览 3

【求解答】bert实现单据文本分类时,如何减少训练数据的规模

我在bert实现单据文本分类时,如何减少训练数据的规模,我只想使用前2000条数据,请各位帮帮我,以下是我的代码:

import numpy as np
from datasets import load_dataset,load_metric
from transformers import BertTokenizerFast,BertForSequenceClassification,TrainingArguments,Trainer
import requests
#加载训练数据、分词器、预训练模型和评价方法
dataset = load_dataset('glue','sst2')#path表示数据集路径,name表示子数据集
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')#分词器对象,这里使用bert的cased版本
model = BertForSequenceClassification.from_pretrained('bert-base-cased',return_dict=True)
metric = load_metric('glue','sst2')


#对训练集分词
def tokenize(examples):#使用bert分词器(tokenizer)对数据进行处理
    return tokenizer(examples['sentence'],truncation=True,padding='max_length')#句子,截断,填充

dataset = dataset.map(tokenize,batched=True)#使用dataset.map方法将tokenize函数应用于整个数据集,batched实现批处理处理。
#将数据集的label标签放入新的数据集encoded_dataset,为模型提供标签信息以进行监督学习任务
encoded_dataset = dataset.map(lambda examples:{'labels':examples['label']},batched=True)


#将数据集格式化位torch.Tensor类型以训练PyTorch模型
#指定encoded_dataset列名
columns = ['input_ids','token_type_ids','attention_mask','labels']
encoded_dataset.set_format(type="torch",columns=columns)

#定义评价指标
def compute_metrics(eval_pred):
    predictions,labels = eval_pred
    return metric.compute(predictions=np.argmax(predictions,axis=1),references=labels)

#定义训练参数TrainingArguments,默认使用AdamW优化器
args = TrainingArguments(
    "ft-sst2",#输出路径,存放检查点和其他输出文件
    evaluation_strategy="epoch",#定义每轮结束后进行评价
    learning_rate=2e-5,#初始学习率
    per_device_train_batch_size=4,#训练批次大小
    per_device_eval_batch_size=4,#测试批次大小
    num_train_epochs=2,#训练轮数
)



#定义Trainer,指定模型和训练参数,输入训练集,验证集,分词器和评价函数
trainer = Trainer(
    model,
    args,
    train_dataset = encoded_dataset["train"],
    eval_dataset = encoded_dataset["validation"],
    tokenizer = tokenizer,
    compute_metrics = compute_metrics
)

#开始训练
trainer.train()

#训练完毕后进行测试
trainer.evaluate()

ps:我尝试使用dataset = dataset["train"][:2000],但出现报错:

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
E:\study\project\python\chap7\Bert_SSC.py:9: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  metric = load_metric('glue','sst2')
Using the latest cached version of the module from C:\Users\Administrator\.cache\huggingface\modules\datasets_modules\metrics\glue\91f3cfc5498873918ecf119dbf806fb10815786c84f41b85a5d3c47c1519b343 (last modified on Sun Oct  8 09:39:10 2023) since it couldn't be found locally at glue, or remotely on the Hugging Face Hub.
Traceback (most recent call last):
  File "E:\study\project\python\chap7\Bert_SSC.py", line 16, in <module>
    dataset = dataset.map(tokenize,batched=True)#使用dataset.map方法将tokenize函数应用于整个数据集,batched实现批处理处理。
AttributeError: 'dict' object has no attribute 'map'


  • 写回答

1条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2023-10-09 06:09
    关注

    引自强大而免费的微信小程序:皆我百晓生。


    在你的代码中,你使用了dataset = load_dataset('glue','sst2')来加载数据集,然后使用dataset = dataset.map(tokenize,batched=True)对数据集进行分词。然而,当你尝试使用dataset = dataset["train"][:2000]来减少训练数据的规模时,出现了报错。

    报错信息提示了AttributeError: 'dict' object has no attribute 'map',这说明dataset已经被转化为一个字典,而不是一个Datasets对象。这可能是因为你选择了数据集的某个子集,而不是整个数据集。因此,当你尝试对dataset进行分词时,会出现错误。

    如果你只想使用前2000条数据进行训练,你可以在加载数据集时使用.select(range(2000))来选取前2000条数据,然后再进行分词和其他处理。修改你的代码如下:

    dataset = load_dataset('glue','sst2')
    dataset = dataset['train'].select(range(2000)) # 选取前2000条训练数据
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    ...
    

    这样就可以只使用前2000条数据进行训练了。

    评论

报告相同问题?

问题事件

  • 创建了问题 10月8日

悬赏问题

  • ¥15 35114 SVAC视频验签的问题
  • ¥15 impedancepy
  • ¥15 在虚拟机环境下完成以下,要求截图!
  • ¥15 求往届大挑得奖作品(ppt…)
  • ¥15 如何在vue.config.js中读取到public文件夹下window.APP_CONFIG.API_BASE_URL的值
  • ¥50 浦育平台scratch图形化编程
  • ¥20 求这个的原理图 只要原理图
  • ¥15 vue2项目中,如何配置环境,可以在打完包之后修改请求的服务器地址
  • ¥20 微信的店铺小程序如何修改背景图
  • ¥15 UE5.1局部变量对蓝图不可见