e555666ee 2025-06-11 08:21 Acceptance rate: 40%

Speech recognition question: fine-tuning Whisper

The Python code below uses Whisper for speech recognition. The task is to fine-tune Whisper on a new dataset to produce a new model that can be used for speech recognition on new content.

import torch
import gc

def clear_gpu_memory():
    # Release GPU memory cached by PyTorch
    torch.cuda.empty_cache()
    
    # Force garbage collection to clean up unreleased Python objects
    gc.collect()
    
    # Optional: reset the current GPU's memory statistics
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    
    print("GPU memory cleared! Current usage:")
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# Run the cleanup
clear_gpu_memory()
import os
import torch
import torchaudio
import pandas as pd
import numpy as np
import logging
from typing import Optional, Dict
from torch.utils.data import Dataset
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    TrainingArguments,
    Trainer,
)
from jiwer import wer, cer
from IPython.display import display, clear_output
from torch.nn.utils.rnn import pad_sequence

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Reduce GPU memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()

# Custom dataset
class WhisperDataset(Dataset):
    def __init__(self, dataframe, audio_folder, processor, sample_rate: int = 16000, remove_punctuation: bool = True):
        self.dataframe = dataframe
        self.audio_folder = audio_folder
        self.processor = processor
        self.sample_rate = sample_rate
        self.remove_punctuation = remove_punctuation

        self.valid_indices = []
        for idx in range(len(dataframe)):
            audio_path = os.path.join(self.audio_folder, dataframe.iloc[idx]["path"])
            if os.path.exists(audio_path):
                self.valid_indices.append(idx)
            else:
                logging.warning(f"音频文件不存在: {audio_path},跳过样本 {idx}")

        logging.info(f"有效样本数量: {len(self.valid_indices)} / {len(dataframe)}")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        actual_idx = self.valid_indices[idx]
        row = self.dataframe.iloc[actual_idx]
        audio_path = os.path.join(self.audio_folder, row["path"])

        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            waveform = waveform.mean(dim=0, keepdim=True)
        except Exception as e:
            logging.error(f"⚠️ 样本 {actual_idx} 音频加载失败: {str(e)},跳过")
            raise

        if sample_rate != self.sample_rate:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)(waveform)

        max_length = self.sample_rate * 30
        if waveform.shape[1] > max_length:
            logging.info(f"⚠️ 样本 {actual_idx} 音频超长({waveform.shape[1] / self.sample_rate:.2f}秒),自动裁剪")
            waveform = waveform[:, :max_length]

        input_features = self.processor.feature_extractor(
            waveform.squeeze(0).numpy(),
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        ).input_features.squeeze(0)

        text = row["sentence"]
        if self.remove_punctuation:
            text = clean_text(text)

        labels = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=128
        ).input_ids.squeeze(0)

        return {"input_features": input_features, "labels": labels}

# Text preprocessing: strip Chinese punctuation (original logic preserved)
def clean_text(text: str) -> str:
    import re
    text = re.sub(r'[《》·?!,。、;:“”‘’()]', '', text)
    return text.strip()

# Collate: stack features and pad labels with -100 so padding is ignored by the loss
def safe_collate(batch):
    input_features = [item["input_features"] for item in batch]
    labels = [item["labels"] for item in batch]
    
    input_features = torch.stack(input_features)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    
    return {
        "input_features": input_features,
        "labels": labels
    }

# Define compute_metrics (WER/CER)
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # With the plain Trainer, predictions are teacher-forced logits; take the argmax
    if isinstance(predictions, np.ndarray) and predictions.ndim == 3:
        pred_ids = np.argmax(predictions, axis=-1)
    else:
        pred_ids = predictions

    labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
    pred_texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = processor.batch_decode(labels, skip_special_tokens=True)

    pred_texts = [text.strip() if text.strip() else "<empty>" for text in pred_texts]
    label_texts = [text.strip() if text.strip() else "<empty>" for text in label_texts]

    try:
        wer_score = wer(label_texts, pred_texts)
        cer_score = cer(label_texts, pred_texts)
    except Exception as e:
        logging.error(f"WER/CER 计算失败: {str(e)}")
        return {"wer": float("inf"), "cer": float("inf")}

    return {"wer": wer_score, "cer": cer_score}

# Initialize the processor and model
processor = WhisperProcessor.from_pretrained("/mnt/whisper-small")
processor.tokenizer.pad_token = processor.tokenizer.eos_token
# Make label sequences carry the Chinese transcribe prefix tokens
processor.tokenizer.set_prefix_tokens(language="zh", task="transcribe")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WhisperForConditionalGeneration.from_pretrained("/mnt/whisper-small").to(device)

# Check proj_out and reinitialize it if its size does not match the vocabulary
def ensure_proj_out(model):
    if model.proj_out.out_features != model.config.vocab_size:
        logging.warning(
            f"🔄 `proj_out` size mismatch, reinitializing: {model.proj_out.out_features} → {model.config.vocab_size}"
        )
        old_weight = model.proj_out.weight.data.clone()
        old_bias = model.proj_out.bias.data.clone() if model.proj_out.bias is not None else None
        model.proj_out = torch.nn.Linear(
            model.config.d_model, model.config.vocab_size, bias=old_bias is not None
        ).to(old_weight.device)  # keep the new layer on the model's device
        with torch.no_grad():
            min_dim = min(old_weight.shape[0], model.config.vocab_size)
            model.proj_out.weight.data[:min_dim] = old_weight[:min_dim]
            if old_bias is not None:
                model.proj_out.bias.data[:min_dim] = old_bias[:min_dim]
            torch.nn.init.normal_(
                model.proj_out.weight.data[min_dim:], mean=0, std=model.config.init_std
            )

if model.proj_out.out_features != model.config.vocab_size:
    ensure_proj_out(model)
else:
    logging.info("`proj_out` already matches the vocabulary size; no reinitialization needed")

# Sync the generation config with the model config (update() takes keyword arguments)
if hasattr(model, "generation_config"):
    model.generation_config.update(**model.config.to_dict())

# Validate metadata and paths
data_dir = "/mnt/zh-CN/cv-corpus-11.0-2022-09-21/zh-CN"
audio_dir = os.path.join(data_dir, "clips/")
for tsv_file in ["train.tsv", "validated.tsv", "dev.tsv"]:
    tsv_path = os.path.join(data_dir, tsv_file)
    if not os.path.exists(tsv_path):
        raise FileNotFoundError(f"Metadata file not found: {tsv_path}")

train_metadata = pd.read_csv(os.path.join(data_dir, "train.tsv"), sep="\t")
validated_metadata = pd.read_csv(os.path.join(data_dir, "validated.tsv"), sep="\t")
dev_metadata = pd.read_csv(os.path.join(data_dir, "dev.tsv"), sep="\t")

# validated.tsv usually overlaps train.tsv, so drop duplicate clips after concatenating
train_metadata = pd.concat([train_metadata, validated_metadata], ignore_index=True)
train_metadata = train_metadata.drop_duplicates(subset="path", ignore_index=True)

# Build the datasets
train_dataset = WhisperDataset(train_metadata, audio_dir, processor, remove_punctuation=True)
eval_dataset = WhisperDataset(dev_metadata, audio_dir, processor, remove_punctuation=True)

logging.info(f"Training set size: {len(train_dataset)}")
logging.info(f"Validation set size: {len(eval_dataset)}")

# 🚀 Training arguments
training_args = TrainingArguments(
    output_dir="/mnt/whisper-checkpoints",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=5000,
    learning_rate=3e-05,
    lr_scheduler_type="linear",
    warmup_steps=500,
    seed=42,
    fp16=True,
    optim="adamw_torch",
    max_grad_norm=1.0,
    eval_strategy="steps",
    eval_steps=2000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=1000,
    weight_decay=0.01,
    load_best_model_at_end=False,
    report_to="none",
    remove_unused_columns=False,
    prediction_loss_only=False,
    dataloader_num_workers=1,
)

# 🚀 Custom Trainer that also saves the processor and the proj_out weights
class FullSaveTrainer(Trainer):
    def __init__(self, processor=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.processor = processor
        self.training_logs = []
        self.table_data = []

    def save_model(self, output_dir=None, *args, **kwargs):
        output_dir = output_dir or self.args.output_dir  # fall back when called without a path
        super().save_model(output_dir, *args, **kwargs)
        if self.processor:
            self.processor.save_pretrained(output_dir)
        ensure_proj_out(self.model)
        torch.save(
            self.model.proj_out.state_dict(), os.path.join(output_dir, "proj_out.pth")
        )

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        logs["epoch"] = round(logs.get("epoch") or 0, 4)
        self.training_logs.append(logs.copy())
        super().log(logs, *args, **kwargs)

        # Trainer logs carry no "step" key; read the global step from trainer state
        step = self.state.global_step
        row = {
            "Training Loss": logs.get("loss", "-"),
            "Epoch": logs["epoch"],
            "Step": step,
            "Validation Loss": logs.get("eval_loss", "-"),
            "WER": logs.get("eval_wer", "-"),
            "CER": logs.get("eval_cer", "-")
        }
        self.table_data.append(row)

        if step % 1000 == 0:
            try:
                clear_output(wait=True)
                display(pd.DataFrame(self.table_data))
            except NameError:
                print(pd.DataFrame(self.table_data).to_string())

# 🚀 Initialize the Trainer
trainer = FullSaveTrainer(
    processor=processor,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=safe_collate,
    compute_metrics=compute_metrics,
)

# 🚀 Train
try:
    logging.info("🚀 Starting training...")
    trainer.train()
    logging.info("✅ Training complete")
    
    final_output_dir = os.path.join(training_args.output_dir, "final_model")
    trainer.save_model(final_output_dir)
    processor.save_pretrained(final_output_dir)
    
    import json
    with open(os.path.join(training_args.output_dir, "training_logs.json"), "w") as f:
        json.dump(trainer.state.log_history, f)
    logging.info(f"📊 Training logs saved to {os.path.join(training_args.output_dir, 'training_logs.json')}")

except Exception as e:
    logging.error(f"Training failed: {str(e)}")
    raise
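
Once training finishes, the model saved under final_model can be loaded for speech recognition on new content. Below is a minimal inference sketch; the file name new_audio.wav is a hypothetical placeholder, not part of the script above:

# Inference sketch (assumes the final_model directory produced by the script above;
# "new_audio.wav" is a hypothetical example file)
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_dir = "/mnt/whisper-checkpoints/final_model"
processor = WhisperProcessor.from_pretrained(model_dir)
model = WhisperForConditionalGeneration.from_pretrained(model_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# Load, downmix to mono, and resample to 16 kHz, matching the training pipeline
waveform, sr = torchaudio.load("new_audio.wav")
waveform = waveform.mean(dim=0)
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

input_features = processor.feature_extractor(
    waveform.numpy(), sampling_rate=16000, return_tensors="pt"
).input_features.to(device)

with torch.no_grad():
    predicted_ids = model.generate(input_features, language="zh", task="transcribe")
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])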


7 answers

  • 阿里嘎多学长 2025-06-11 08:21


    Solution

    You want to fine-tune Whisper so it can recognize speech from a new dataset. Whisper is a transformer-based speech recognition model, and it can be adapted to new data via transfer learning.

    Here are the steps you can try:

    1. Collect a new dataset and split it into training, validation, and test sets.
    2. Start from a pretrained Whisper model.
    3. Fine-tune the model on the new dataset using the Adam optimizer and a cross-entropy loss.
    4. Evaluate the model on the validation set and tune hyperparameters to improve accuracy.
    5. Use the final model for speech recognition on new content.

    Here is a simple example of fine-tuning a Whisper model:

    import torch
    from torch.optim import AdamW
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    
    # Load the pretrained model and processor
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Optimizer; Whisper computes its own cross-entropy loss when labels are passed
    optimizer = AdamW(model.parameters(), lr=1e-5)
    
    # Fine-tune: each batch provides log-mel input features and tokenized labels
    model.train()
    for epoch in range(10):
        for batch in train_loader:
            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()
            outputs = model(input_features=input_features, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
    
    # Evaluate with the average loss on the held-out set
    # (for speech recognition, WER/CER on generated transcripts is the usual metric)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_features=input_features, labels=labels)
            test_loss += outputs.loss.item()
    
    print(f'Test Loss: {test_loss / len(test_loader):.4f}')
    

    Note that this is only a simple example; you may need to adjust it for your own dataset and requirements.
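
    For evaluation that reflects real decoding, a common alternative is the Hugging Face Seq2SeqTrainer with predict_with_generate=True, which computes WER on generated transcripts rather than teacher-forced logits. A minimal sketch, assuming the model, datasets, collator, and compute_metrics from the question's script are in scope:

    from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
    
    # Sketch only: reuses model, train_dataset, eval_dataset, safe_collate and
    # compute_metrics defined in the question's script
    seq2seq_args = Seq2SeqTrainingArguments(
        output_dir="/mnt/whisper-checkpoints",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        max_steps=5000,
        fp16=True,
        eval_strategy="steps",
        eval_steps=1000,
        predict_with_generate=True,  # evaluate with generate() so WER reflects inference
        remove_unused_columns=False,
        report_to="none",
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=seq2seq_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=safe_collate,
        compute_metrics=compute_metrics,
    )
    trainer.train()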


