import os
import time
import uuid
import json
import inspect
import traceback

import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uvicorn
from transformers import pipeline as hf_pipeline

os.makedirs("./input", exist_ok=True)

status = "Running"
asr_pipeline = None
is_whisper = False  # The only branch we need to distinguish: Whisper (seq2seq) vs. all other CTC-style models
device = ""
app = FastAPI()

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu


class _SamplingRateCompatProxy:
    """Compatibility wrapper for non-standard FeatureExtractors.

    The transformers pipeline's preprocess always passes standard kwargs such as
    sampling_rate and return_tensors to the feature_extractor, but some models'
    FeatureExtractors (e.g. GraniteSpeech) do not implement these parameters.
    This proxy inspects the signature once at init time and, when called, forwards
    only the parameters the FeatureExtractor actually accepts.

    Callers must ensure the audio has already been resampled to the model's
    expected sampling rate before invocation (run_asr takes care of this).
    """

    def __init__(self, fe):
        object.__setattr__(self, "_fe", fe)
        # Inspect the signature once at init time to determine which kwargs are accepted
        try:
            sig = inspect.signature(fe.__call__)
            has_var_kw = any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
            )
            accepted = None if has_var_kw else set(sig.parameters.keys()) - {"self"}
        except Exception:
            accepted = None  # If the signature cannot be inspected, do not filter
        object.__setattr__(self, "_accepted", accepted)

    def __call__(self, *args, **kwargs):
        accepted = object.__getattribute__(self, "_accepted")
        if accepted is not None:
            kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return object.__getattribute__(self, "_fe")(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(object.__getattribute__(self, "_fe"), name)

    def __setattr__(self, name, value):
        setattr(object.__getattribute__(self, "_fe"), name, value)


def _check_is_whisper(model_dir: str, model_type_override: str = None) -> bool:
    """Determine whether the model is a Whisper architecture.

    Prefer the user-supplied model_type_override; otherwise read the model_type
    field from config.json (all Whisper fine-tuned models carry this field).
    """
    if model_type_override:
        return model_type_override.lower() == "whisper"
    config_path = os.path.join(model_dir, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            cfg = json.load(f)
        return cfg.get("model_type", "").lower() == "whisper"
    # If config.json is missing, fall back to the directory name as a last resort
    return "whisper" in os.path.basename(model_dir).lower()
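# Illustrative excerpt of the field _check_is_whisper reads. This is what a Whisper
# fine-tune's config.json typically contains (surrounding keys vary by model):
#
#   {
#     "architectures": ["WhisperForConditionalGeneration"],
#     "model_type": "whisper",
#     ...
#   }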
@app.on_event("startup")
def load_model():
    global status, asr_pipeline, is_whisper, device
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type_override = config.get("model_type", None)  # Optional; only overrides auto-detection
    warmup = config.get("warmup", False)
    use_fp16 = config.get("fp16", False)  # Default fp32; must be enabled explicitly by the user

    # Device-string logic kept consistent with fastapi_funasr.py; the string is passed straight to the pipeline
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    is_whisper = _check_is_whisper(model_dir, model_type_override)

    # Default to fp32: best cross-platform compatibility and no impact on accuracy comparisons.
    # fp16 must be enabled explicitly (--fp16), and the hardware should be confirmed to support it.
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" is_whisper =", is_whisper, flush=True)
    print(" device =", device, flush=True)
    print(" torch_dtype =", torch_dtype, flush=True)
    print(" chunk_length_s =", config.get("chunk_length_s", 30), flush=True)
    print(" warmup =", warmup, flush=True)

    # The transformers pipeline accepts device strings directly ("cpu"/"cuda:0"/"mlu:0"/"npu:0")
    # and reads config.json to instantiate the right model class, so no architecture is specified here.
    # Note: chunk_length_s is NOT passed when building the pipeline; run_asr chunks the audio itself
    # and calls the pipeline per chunk. Reason: some models' FeatureExtractors (e.g. GraniteSpeech)
    # do not accept a sampling_rate parameter, while the pipeline's internal chunk_iter always passes
    # it, which raises an error.
    asr_pipeline = hf_pipeline(
        task="automatic-speech-recognition",
        model=model_dir,
        device=device,
        torch_dtype=torch_dtype,
    )

    # Check whether the feature extractor accepts a sampling_rate kwarg.
    # The pipeline's preprocess always passes it (hard-coded behavior);
    # models that do not accept it need the compat proxy.
    try:
        sig = inspect.signature(asr_pipeline.feature_extractor.__call__)
        accepts_sr = "sampling_rate" in sig.parameters or any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
    except Exception:
        accepts_sr = True  # If detection fails, conservatively assume it is accepted
    if not accepts_sr:
        asr_pipeline.feature_extractor = _SamplingRateCompatProxy(asr_pipeline.feature_extractor)
        print(" Note: FeatureExtractor does not accept sampling_rate, applied compat proxy", flush=True)

    if warmup:
        print("Start warmup...", flush=True)
        # Get the model's expected sampling rate; nearly all feature extractors expose this
        # attribute. A few non-standard models may not, so fall back to 16000 (the most
        # common standard sampling rate in ASR).
        target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
        dummy = np.zeros(target_sr, dtype=np.float32)  # 1 second of silence
        asr_pipeline(dummy, **_build_infer_kwargs("zh"))
        print("warmup complete.", flush=True)

    status = "Success"


def _build_infer_kwargs(lang: str) -> dict:
    """Whisper inference needs an explicit language; CTC-style models need no extra kwargs.

    return_timestamps is no longer passed: we chunk the audio ourselves and call the
    pipeline per chunk, so the pipeline's internal stitching is not needed.
    """
    if is_whisper:
        return {"generate_kwargs": {"language": lang, "task": "transcribe"}}
    return {}


def run_asr(audio_file: str, lang: str) -> str:
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate

    # Downmix multi-channel audio to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to the model's expected sampling rate up front, then pass a numpy array
    # (not a dict) to the pipeline. This skips the pipeline's internal sampling_rate
    # plumbing and avoids the issue where some models' FeatureExtractors
    # (e.g. GraniteSpeech) reject the sampling_rate parameter.
    # Nearly all feature extractors expose sampling_rate; a few non-standard models may
    # not, so fall back to 16000 (the most common standard sampling rate in ASR).
    target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)
    audio_array = waveform.squeeze(0).numpy().astype(np.float32)

    chunk_length_s = app.state.config.get("chunk_length_s", 30)
    chunk_samples = int(chunk_length_s * target_sr)  # int() keeps range() happy if chunk_length_s is a float
    infer_kwargs = _build_infer_kwargs(lang)

    ts1 = time.time()
    texts = []
    for i in range(0, len(audio_array), chunk_samples):
        chunk = audio_array[i : i + chunk_samples]
        result = asr_pipeline(chunk, **infer_kwargs)
        texts.append(result["text"])
    ts2 = time.time()
    generated_text = "".join(texts)
    # wav2vec2-family models use U+2581 (▁) as the inter-word separator; replace it with a space
    # (the \u escape avoids any ambiguity about the literal character in this source file)
    generated_text = generated_text.replace("\u2581", " ").strip()

    processing_time = ts2 - ts1
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text


@app.get("/health")
def health():
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}


@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None,
):
    try:
        file_path = f"./input/{uuid.uuid4()}.wav"
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        background_tasks.add_task(os.remove, file_path)

        generated_text = run_asr(file_path, lang)
        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
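# Example requests against the two endpoints (hypothetical host/port and file name):
#
#   curl http://127.0.0.1:8000/health
#   # -> {"status": "ok"}
#
#   curl -F "audio=@sample.wav" -F "lang=zh" http://127.0.0.1:8000/transduce
#   # -> {"generated_text": "..."}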
# if __name__ == "__main__":
#     uvicorn.run("fastapi_transformers:app", host="0.0.0.0", port=8000, workers=1)
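# --- Launcher sketch (assumption: not part of this module) ---
# load_model reads app.state.config at startup but nothing in this file sets it, so a
# separate launcher is expected to inject it. A minimal illustrative version is below;
# the flag names mirror the config keys used above but are hypothetical here. Note that
# app.state set in __main__ only survives when the app *object* is passed to
# uvicorn.run; the import-string form above re-imports the module and loses it.
#
# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model_dir", default="/model")
#     parser.add_argument("--model_type", default=None)
#     parser.add_argument("--fp16", action="store_true")
#     parser.add_argument("--chunk_length_s", type=int, default=30)
#     parser.add_argument("--warmup", action="store_true")
#     args = parser.parse_args()
#     app.state.config = vars(args)
#     uvicorn.run(app, host="0.0.0.0", port=8000)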