Hi everyone, I'm trying to fine-tune a BERT model (DNABERT-S), but I keep getting the same error:
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
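CUDA reports this assert asynchronously, so the traceback may point at the wrong line. One way to follow the log's own suggestion is to set CUDA_LAUNCH_BLOCKING before PyTorch initializes CUDA, which means putting it at the very top of the training script. Running `CUDA_LAUNCH_BLOCKING=1 python train.py` in the shell has the same effect (train.py is a placeholder name). This is only a debugging aid: it makes every kernel launch synchronous and slows training down.

```python
import os

# Make CUDA kernel launches synchronous so the Python traceback points at the
# operation that actually failed. Set it before torch touches CUDA, i.e. above
# `import torch` and above any transformers/datasets imports.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  # import the rest of the training script only after this line
```

Another option is to run a single batch on the CPU. There, the same problem usually shows up as an ordinary Python exception such as `IndexError: Target 30 is out of bounds.` (label out of range) or `IndexError: index out of range in self` (embedding lookup out of range), which tells you which input is wrong.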
Here is my code:
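import pandas as pd
import torch
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Not used below: Trainer moves the model and the batches to the GPU by itself
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")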
# Load the data
df = pd.read_csv('/root/manu_data.csv')
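# Drop the 'Noise' rows; .copy() avoids pandas' SettingWithCopyWarning on the next assignment
df = df[df['label'] != 'Noise'].copy()
# Convert the labels to int and shift them from 1..30 to 0..29 (num_labels=30 expects IDs 0..29)
df['label'] = df['label'].astype(int) - 1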
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['sequence'].tolist(),
    df['label'].astype(int).tolist(),
    test_size=0.2,
)
tokenizer = AutoTokenizer.from_pretrained('/root/Model/DNABERT_s')
model = BertForSequenceClassification.from_pretrained('/root/Model/DNABERT_s', num_labels=30)
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=512)
train_dataset = Dataset.from_dict({'input_ids': train_encodings['input_ids'], 'attention_mask': train_encodings['attention_mask'], 'labels': train_labels})
val_dataset = Dataset.from_dict({'input_ids': val_encodings['input_ids'], 'attention_mask': val_encodings['attention_mask'], 'labels': val_labels})
# Define the training arguments
training_args = TrainingArguments(
    output_dir='/root/Model/finetuned_DNABERT',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='/root/Model/finetuned_DNABERT/logs',
    logging_steps=10,
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    save_total_limit=2,
    fp16=True,
)
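# metric_for_best_model="accuracy" makes Trainer look up "eval_accuracy" after each
# evaluation, so the accuracy has to be computed and returned via compute_metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred  # NumPy arrays gathered over the whole eval set
    preds = logits.argmax(axis=-1)
    return {"accuracy": float((preds == labels).mean())}

# Define a Trainer object
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)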
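Apart from labels, another common source of a device-side assert in BERT fine-tuning is an embedding lookup out of range: a token ID at or above the model's vocabulary size, or an input longer than the position-embedding table. A minimal sketch that checks the tokenized inputs on the CPU before training follows. check_token_inputs is a hypothetical helper, and the sketch assumes the limits are taken from model.config.

```python
from typing import Sequence


def check_token_inputs(
    input_ids: Sequence[Sequence[int]],
    vocab_size: int,
    max_positions: int,
) -> None:
    """Raise a readable ValueError instead of a CUDA device-side assert.

    input_ids: one list of token IDs per example (e.g. train_encodings["input_ids"]).
    vocab_size: number of rows in the model's word-embedding table.
    max_positions: number of rows in the model's position-embedding table.
    """
    for i, ids in enumerate(input_ids):
        # Position IDs run 0..len-1, so a longer input indexes past the table
        if len(ids) > max_positions:
            raise ValueError(
                f"example {i}: length {len(ids)} exceeds max_positions={max_positions}"
            )
        # Every token ID must be a valid row of the word-embedding table
        bad = [t for t in ids if t < 0 or t >= vocab_size]
        if bad:
            raise ValueError(
                f"example {i}: token IDs {sorted(set(bad))[:5]} are outside [0, {vocab_size})"
            )
```

For example: `check_token_inputs(train_encodings["input_ids"], model.config.vocab_size, model.config.max_position_embeddings)`, and the same for `val_encodings`. It is also worth comparing `len(tokenizer)` with `model.config.vocab_size`. If the tokenizer is larger, the tokenizer and the loaded weights do not match.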
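I checked with print(f"Unique labels in the dataset: {df['label'].unique()}"), and it prints 0-29, i.e. 30 labels in total. Could someone take a look and tell me where the problem is? Any help is appreciated, thanks.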
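Printing the unique values is a good start. A slightly stricter version maps the raw labels to contiguous IDs and checks the exact lists passed to the model against num_labels, so that a label problem becomes a readable error instead of a CUDA assert. This is a minimal sketch, assuming the raw CSV labels (after removing "Noise") are integer-like strings such as "1" to "30". encode_labels and check_label_range are hypothetical helpers. When the raw labels are exactly 1..30, the mapping is identical to `astype(int) - 1`. It also stays within range if a class number is missing.

```python
from typing import Dict, List, Sequence, Tuple


def encode_labels(raw_labels: Sequence[str]) -> Tuple[List[int], Dict[str, int]]:
    """Map raw class names to contiguous IDs 0..N-1, ordered by numeric value."""
    # Assumes integer-like names such as "1".."30"; sort numerically, not lexically
    classes = sorted(set(raw_labels), key=int)
    label2id = {name: idx for idx, name in enumerate(classes)}
    return [label2id[name] for name in raw_labels], label2id


def check_label_range(labels: Sequence[int], num_labels: int) -> None:
    """Fail on the CPU before training if any label is outside [0, num_labels)."""
    bad = sorted({y for y in labels if y < 0 or y >= num_labels})
    if bad:
        raise ValueError(f"labels {bad} are outside [0, {num_labels})")
```

Used in place of the `astype(int) - 1` line, this could look like `df['label'], label2id = encode_labels(df['label'].tolist())`, followed by `num_labels=len(label2id)` in `from_pretrained` and `check_label_range(train_labels, len(label2id))` / `check_label_range(val_labels, len(label2id))` after the split.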