The Python code below previously called Whisper directly for speech recognition. The task now is to fine-tune Whisper on a dataset, producing a new model that can be used for speech recognition on new content.
import torch
import gc

def clear_gpu_memory():
    # Release GPU memory cached by PyTorch
    torch.cuda.empty_cache()
    # Force garbage collection to clean up unreleased Python-level objects
    gc.collect()
    # Optional: reset the current GPU's memory statistics
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    print("GPU memory cleared! Current memory usage:")
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# Run the cleanup
clear_gpu_memory()
import os
import torch
import torchaudio
import pandas as pd
import numpy as np
import logging
from typing import Optional, Dict
from torch.utils.data import Dataset
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
TrainingArguments,
Trainer,
)
from jiwer import wer, cer
from IPython.display import display, clear_output
from torch.nn.utils.rnn import pad_sequence
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
# Custom dataset
class WhisperDataset(Dataset):
    def __init__(self, dataframe, audio_folder, processor, sample_rate: int = 16000, remove_punctuation: bool = True):
        self.dataframe = dataframe
        self.audio_folder = audio_folder
        self.processor = processor
        self.sample_rate = sample_rate
        self.remove_punctuation = remove_punctuation
        # Keep only the rows whose audio file actually exists on disk
        self.valid_indices = []
        for idx in range(len(dataframe)):
            audio_path = os.path.join(self.audio_folder, dataframe.iloc[idx]["path"])
            if os.path.exists(audio_path):
                self.valid_indices.append(idx)
            else:
                logging.warning(f"Audio file not found: {audio_path}, skipping sample {idx}")
        logging.info(f"Valid samples: {len(self.valid_indices)} / {len(dataframe)}")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        actual_idx = self.valid_indices[idx]
        row = self.dataframe.iloc[actual_idx]
        audio_path = os.path.join(self.audio_folder, row["path"])
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            # Downmix to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        except Exception as e:
            logging.error(f"⚠️ Sample {actual_idx} failed to load: {str(e)}")
            raise
        if sample_rate != self.sample_rate:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)(waveform)
        # Whisper operates on 30-second windows; clip anything longer
        max_length = self.sample_rate * 30
        if waveform.shape[1] > max_length:
            logging.info(f"⚠️ Sample {actual_idx} is too long ({waveform.shape[1] / self.sample_rate:.2f} s), clipping")
            waveform = waveform[:, :max_length]
        input_features = self.processor.feature_extractor(
            waveform.squeeze(0).numpy(),
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        ).input_features.squeeze(0)
        text = row["sentence"]
        if self.remove_punctuation:
            text = clean_text(text)
        labels = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=128
        ).input_ids.squeeze(0)
        return {"input_features": input_features, "labels": labels}
# Text preprocessing (original logic kept): strip common Chinese punctuation
def clean_text(text: str) -> str:
    import re
    text = re.sub(r'[《》·?!,。、;:“”‘’()]', '', text)
    return text.strip()
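# A quick, illustrative sanity check for clean_text (the sample sentence is made up, not from the
# dataset): only the fullwidth punctuation listed in the regex is removed.
assert clean_text("你好,世界!《测试》") == "你好世界测试"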
# Collate function: stack the fixed-size log-mel features and pad the label sequences.
# Label padding uses -100 so that padded positions are ignored by the loss, which also
# matches the -100 handling in compute_metrics below.
def safe_collate(batch):
    input_features = [item["input_features"] for item in batch]
    labels = [item["labels"] for item in batch]
    input_features = torch.stack(input_features)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    return {
        "input_features": input_features,
        "labels": labels
    }
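# Minimal sanity check for safe_collate using dummy tensors (the shapes are illustrative:
# whisper-small log-mel features are 80 mel bins × 3000 frames for a 30-second window).
_dummy = safe_collate([
    {"input_features": torch.zeros(80, 3000), "labels": torch.tensor([1, 2, 3])},
    {"input_features": torch.zeros(80, 3000), "labels": torch.tensor([4, 5])},
])
assert _dummy["input_features"].shape == (2, 80, 3000)
assert _dummy["labels"].shape == (2, 3)  # the shorter label sequence is padded to the longest one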
# Metrics: decode predictions and references, then compute WER and CER with jiwer
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Without predict_with_generate the Trainer passes logits, so take the argmax over the vocabulary
    if isinstance(predictions, np.ndarray) and predictions.ndim == 3:
        pred_ids = np.argmax(predictions, axis=-1)
    else:
        pred_ids = predictions
    # Replace the -100 label padding with the pad token id before decoding
    labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
    pred_texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = processor.batch_decode(labels, skip_special_tokens=True)
    pred_texts = [text.strip() if text.strip() else "<empty>" for text in pred_texts]
    label_texts = [text.strip() if text.strip() else "<empty>" for text in label_texts]
    try:
        wer_score = wer(label_texts, pred_texts)
        cer_score = cer(label_texts, pred_texts)
    except Exception as e:
        logging.error(f"WER/CER computation failed: {str(e)}")
        return {"wer": float("inf"), "cer": float("inf")}
    return {"wer": wer_score, "cer": cer_score}
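# A quick, illustrative look at the metrics on made-up strings (not from the dataset). jiwer's
# wer() scores whitespace-separated tokens, so for unsegmented Chinese text the CER is usually
# the more informative of the two numbers.
_example_refs = ["今天 天气 很好"]
_example_hyps = ["今天 天气 不错"]
logging.info(f"Example WER: {wer(_example_refs, _example_hyps):.3f}")  # 1 of 3 tokens differs -> ~0.333
logging.info(f"Example CER: {cer(_example_refs, _example_hyps):.3f}")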
# Initialize the processor and model
processor = WhisperProcessor.from_pretrained("/mnt/whisper-small")
processor.tokenizer.pad_token = processor.tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WhisperForConditionalGeneration.from_pretrained("/mnt/whisper-small").to(device)
# Check proj_out and re-initialize it if its size does not match the vocabulary
def ensure_proj_out(model):
    if model.proj_out.out_features != model.config.vocab_size:
        logging.warning(
            f"🔄 `proj_out` size mismatch, re-initializing: {model.proj_out.out_features} → {model.config.vocab_size}"
        )
        old_weight = model.proj_out.weight.data.clone()
        old_bias = model.proj_out.bias.data.clone() if model.proj_out.bias is not None else None
        # Keep the replacement layer on the same device as the original weights
        model.proj_out = torch.nn.Linear(
            model.config.d_model, model.config.vocab_size, bias=old_bias is not None
        ).to(old_weight.device)
        with torch.no_grad():
            # Copy the overlapping rows and randomly initialize any newly added ones
            min_dim = min(old_weight.shape[0], model.config.vocab_size)
            model.proj_out.weight.data[:min_dim] = old_weight[:min_dim]
            if old_bias is not None:
                model.proj_out.bias.data[:min_dim] = old_bias[:min_dim]
            torch.nn.init.normal_(
                model.proj_out.weight.data[min_dim:], mean=0, std=model.config.init_std
            )

if model.proj_out.out_features != model.config.vocab_size:
    ensure_proj_out(model)
else:
    logging.info("`proj_out` already matches the vocabulary size, no re-initialization needed")
# Keep the generation config in sync with the model config. generation_config lives on the model
# itself (not on model.config), and GenerationConfig.update takes keyword arguments.
if hasattr(model, "generation_config"):
    model.generation_config.update(**model.config.to_dict())
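# Optional, and not part of the original script: for a Chinese-only fine-tune it can help to pin
# the decoding language and task on the generation config so later inference does not rely on
# Whisper's automatic language detection.
# model.generation_config.language = "zh"
# model.generation_config.task = "transcribe"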
# Verify the metadata files and paths
data_dir = "/mnt/zh-CN/cv-corpus-11.0-2022-09-21/zh-CN"
audio_dir = os.path.join(data_dir, "clips/")
for tsv_file in ["train.tsv", "validated.tsv", "dev.tsv"]:
    tsv_path = os.path.join(data_dir, tsv_file)
    if not os.path.exists(tsv_path):
        raise FileNotFoundError(f"Metadata file not found: {tsv_path}")
train_metadata = pd.read_csv(os.path.join(data_dir, "train.tsv"), sep="\t")
validated_metadata = pd.read_csv(os.path.join(data_dir, "validated.tsv"), sep="\t")
dev_metadata = pd.read_csv(os.path.join(data_dir, "dev.tsv"), sep="\t")
# Merge train.tsv and validated.tsv to enlarge the training set. Note: in Common Voice,
# validated.tsv is a superset that overlaps train.tsv and dev.tsv, so this merge duplicates
# training rows and can leak dev sentences into training.
train_metadata = pd.concat([train_metadata, validated_metadata], ignore_index=True)
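# The Common Voice TSVs are tab-separated and include, among others, the "path" and "sentence"
# columns, which are the only two this script relies on. Log a quick preview as a sanity check.
logging.info(f"Train metadata columns: {list(train_metadata.columns)}")
logging.info(f"Train metadata rows after merging validated.tsv: {len(train_metadata)}")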
# Build the datasets
train_dataset = WhisperDataset(train_metadata, audio_dir, processor, remove_punctuation=True)
eval_dataset = WhisperDataset(dev_metadata, audio_dir, processor, remove_punctuation=True)
logging.info(f"Training set size: {len(train_dataset)}")
logging.info(f"Validation set size: {len(eval_dataset)}")
# 🚀 Training arguments
training_args = TrainingArguments(
    output_dir="/mnt/whisper-checkpoints",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=5000,
    learning_rate=3e-05,
    lr_scheduler_type="linear",
    warmup_steps=500,
    seed=42,
    fp16=True,
    optim="adamw_torch",
    max_grad_norm=1.0,
    eval_strategy="steps",
    eval_steps=2000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=1000,
    weight_decay=0.01,
    load_best_model_at_end=False,
    report_to="none",
    remove_unused_columns=False,
    prediction_loss_only=False,
    dataloader_num_workers=1,
)
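# With these settings each optimizer step sees per_device_train_batch_size * gradient_accumulation_steps
# samples, so on a single GPU the 5000 max_steps correspond to roughly 80,000 training examples
# processed (a rough estimate that ignores multi-GPU scaling).
effective_batch_size = (
    training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
)
logging.info(f"Effective batch size per optimizer step (single GPU): {effective_batch_size}")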
# 🚀 Custom Trainer that also saves the processor and proj_out weights with every checkpoint
class FullSaveTrainer(Trainer):
    def __init__(self, processor=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.processor = processor
        self.training_logs = []
        self.table_data = []

    def save_model(self, output_dir=None, *args, **kwargs):
        super().save_model(output_dir, *args, **kwargs)
        # Fall back to the default output dir when the Trainer calls save_model(None)
        if output_dir is None:
            output_dir = self.args.output_dir
        if self.processor:
            self.processor.save_pretrained(output_dir)
        ensure_proj_out(self.model)
        torch.save(
            self.model.proj_out.state_dict(), os.path.join(output_dir, "proj_out.pth")
        )

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        logs["epoch"] = round(logs.get("epoch", 0), 4)
        self.training_logs.append(logs.copy())
        super().log(logs, *args, **kwargs)
        # The logs dict itself carries no "step" key; read the step from the trainer state instead
        step = self.state.global_step
        row = {
            "Training Loss": logs.get("loss", "-"),
            "Epoch": logs["epoch"],
            "Step": step,
            "Validation Loss": logs.get("eval_loss", "-"),
            "Wer": logs.get("eval_wer", "-"),
            "CER": logs.get("eval_cer", "-")
        }
        self.table_data.append(row)
        if step % 1000 == 0:
            try:
                clear_output(wait=True)
                df = pd.DataFrame(self.table_data)
                display(df)
            except NameError:
                print(pd.DataFrame(self.table_data).to_string())
# 🚀 Initialize the Trainer
trainer = FullSaveTrainer(
    processor=processor,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=safe_collate,
    compute_metrics=compute_metrics,
)
# 🚀 Train
try:
    logging.info("🚀 Starting training...")
    trainer.train()
    logging.info("✅ Training finished")
    final_output_dir = os.path.join(training_args.output_dir, "final_model")
    trainer.save_model(final_output_dir)
    processor.save_pretrained(final_output_dir)
    import json
    with open(os.path.join(training_args.output_dir, "training_logs.json"), "w") as f:
        json.dump(trainer.state.log_history, f)
    logging.info(f"📊 Training logs saved to {os.path.join(training_args.output_dir, 'training_logs.json')}")
except Exception as e:
    logging.error(f"Training failed: {str(e)}")
    raise
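# A minimal inference sketch with the fine-tuned model. This is an illustration rather than part of
# the training run, and "/path/to/new_audio.wav" below is a placeholder for a new recording.
def transcribe(audio_path: str, model_dir: str = os.path.join(training_args.output_dir, "final_model")) -> str:
    asr_processor = WhisperProcessor.from_pretrained(model_dir)
    asr_model = WhisperForConditionalGeneration.from_pretrained(model_dir).to(device)
    asr_model.eval()
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # downmix to mono
    if sr != 16000:
        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)
    features = asr_processor.feature_extractor(
        wav.numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    with torch.no_grad():
        pred_ids = asr_model.generate(features)
    return asr_processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

# Example call (the audio path is hypothetical):
# print(transcribe("/path/to/new_audio.wav"))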