enginex-bi_150-asr/transformers/fastapi_transformers.py

import os
import time
import uuid
import json
import inspect
import traceback
import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uvicorn
from transformers import pipeline as hf_pipeline
os.makedirs("./input", exist_ok=True)
status = "Running"
asr_pipeline = None
is_whisper = False  # The only branch we need to distinguish: Whisper (seq2seq) vs. all other CTC-style models
device = ""
app = FastAPI()
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
import torch_dipu
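# The vendor plugins (torch_mlu / torch_npu / torch_dipu) register their device backends
# with torch on import, so the bare import is sufficient. Assumed convention for this
# service, inferred from the startswith() checks above:
#   CUSTOM_DEVICE=mlu...    -> Cambricon MLU
#   CUSTOM_DEVICE=ascend... -> Huawei Ascend NPU
#   CUSTOM_DEVICE=pt...     -> DIPU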
class _SamplingRateCompatProxy:
"""为非标准 FeatureExtractor 提供兼容性包装。
transformers pipeline 的 preprocess 固定会向 feature_extractor 传 sampling_rate、
return_tensors 等标准 kwargs但部分模型如 GraniteSpeech的 FeatureExtractor
没有实现这些参数。此代理在初始化时检查签名,调用时只转发 FeatureExtractor 实际接受的参数。
调用前须确保音频已按模型期望采样率重采样完毕run_asr 中已完成)。
"""
def __init__(self, fe):
object.__setattr__(self, "_fe", fe)
        # Inspect the signature once at init time to determine which parameters are accepted
try:
sig = inspect.signature(fe.__call__)
has_var_kw = any(
p.kind == inspect.Parameter.VAR_KEYWORD
for p in sig.parameters.values()
)
accepted = None if has_var_kw else set(sig.parameters.keys()) - {"self"}
except Exception:
            accepted = None  # don't filter if the signature can't be inspected
object.__setattr__(self, "_accepted", accepted)
def __call__(self, *args, **kwargs):
accepted = object.__getattribute__(self, "_accepted")
if accepted is not None:
kwargs = {k: v for k, v in kwargs.items() if k in accepted}
return object.__getattribute__(self, "_fe")(*args, **kwargs)
def __getattr__(self, name):
return getattr(object.__getattribute__(self, "_fe"), name)
def __setattr__(self, name, value):
setattr(object.__getattribute__(self, "_fe"), name, value)
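# Illustrative sketch of the proxy's effect (bare_extractor is a hypothetical
# FeatureExtractor whose __call__ accepts only `raw_speech`; not part of this service):
#
#   fe = _SamplingRateCompatProxy(bare_extractor)
#   fe(audio, sampling_rate=16000, return_tensors="pt")
#   # -> equivalent to bare_extractor(audio); the unsupported kwargs are silently dropped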
def _check_is_whisper(model_dir: str, model_type_override: str = None) -> bool:
"""判断是否为 Whisper 架构。
优先使用用户显式传入的 model_type_override
否则读 config.json 中的 model_type 字段(所有 whisper fine-tuned 模型均有此字段)。
"""
if model_type_override:
return model_type_override.lower() == "whisper"
config_path = os.path.join(model_dir, "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
cfg = json.load(f)
return cfg.get("model_type", "").lower() == "whisper"
    # Fall back to the directory name when config.json is missing
return "whisper" in os.path.basename(model_dir).lower()
@app.on_event("startup")
def load_model():
global status, asr_pipeline, is_whisper, device
config = app.state.config
use_gpu = config.get("use_gpu", True)
model_dir = config.get("model_dir", "/model")
    model_type_override = config.get("model_type", None)  # optional; only used to override auto-detection
warmup = config.get("warmup", False)
    use_fp16 = config.get("fp16", False)  # defaults to fp32; must be enabled explicitly
    # Device-string logic kept consistent with fastapi_funasr.py; the string is passed to the pipeline as-is
device = "cpu"
if use_gpu:
if CUSTOM_DEVICE.startswith("mlu"):
device = "mlu:0"
elif CUSTOM_DEVICE.startswith("ascend"):
device = "npu:0"
else:
device = "cuda:0"
is_whisper = _check_is_whisper(model_dir, model_type_override)
    # Default to fp32: best cross-platform compatibility, and it keeps precision comparisons fair.
    # fp16 must be enabled explicitly (--fp16), and the user should confirm the hardware supports it.
torch_dtype = torch.float16 if use_fp16 else torch.float32
print(">> Startup config:")
print(" model_dir =", model_dir, flush=True)
print(" is_whisper =", is_whisper, flush=True)
print(" device =", device, flush=True)
print(" torch_dtype =", torch_dtype, flush=True)
print(" chunk_length_s =", app.state.config.get("chunk_length_s", 30), flush=True)
print(" warmup =", warmup, flush=True)
    # The transformers pipeline accepts a device string directly ("cpu"/"cuda:0"/"mlu:0"/"npu:0")
    # and reads config.json to instantiate the right model class, so no architecture needs to be specified.
    # Note: chunk_length_s is NOT passed when building the pipeline; run_asr does its own chunking
    # and calls the pipeline per chunk. Reason: some models' FeatureExtractors (e.g. GraniteSpeech)
    # do not accept a sampling_rate parameter, while the pipeline's internal chunk_iter always passes
    # it, which raises an error.
asr_pipeline = hf_pipeline(
task="automatic-speech-recognition",
model=model_dir,
device=device,
torch_dtype=torch_dtype,
)
    # Check whether the feature extractor accepts a sampling_rate parameter.
    # The pipeline's preprocess always passes it (hardcoded behavior), so models that reject it need the compat proxy.
try:
sig = inspect.signature(asr_pipeline.feature_extractor.__call__)
accepts_sr = "sampling_rate" in sig.parameters or any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
except Exception:
        accepts_sr = True  # assume it is accepted when the signature can't be inspected
if not accepts_sr:
asr_pipeline.feature_extractor = _SamplingRateCompatProxy(asr_pipeline.feature_extractor)
print(" Note: FeatureExtractor does not accept sampling_rate, applied compat proxy", flush=True)
if warmup:
print("Start warmup...", flush=True)
        # Get the model's expected sampling rate; nearly every feature extractor has this attribute.
        # A few non-standard models may not, so fall back to 16000 (the most common ASR sampling rate).
target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
        dummy = np.zeros(target_sr, dtype=np.float32)  # 1 second of silence
asr_pipeline(dummy, **_build_infer_kwargs("zh"))
print("warmup complete.", flush=True)
status = "Success"
def _build_infer_kwargs(lang: str) -> dict:
"""Whisper 推理时需要额外传语言参数CTC 类无需额外参数。
不再传 return_timestamps因为我们自行分片后逐段调用 pipeline无需 pipeline 内部拼接。
"""
if is_whisper:
return {"generate_kwargs": {"language": lang, "task": "transcribe"}}
return {}
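# e.g. with a Whisper model, _build_infer_kwargs("zh") returns
# {"generate_kwargs": {"language": "zh", "task": "transcribe"}}; for CTC models it returns {}.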
def run_asr(audio_file: str, lang: str) -> str:
waveform, sample_rate = torchaudio.load(audio_file)
duration = waveform.shape[1] / sample_rate
    # Downmix multi-channel audio to mono
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
    # Resample up front to the model's expected sampling rate.
    # Pass a numpy array (not a dict) to the pipeline to skip its internal sampling_rate plumbing,
    # working around FeatureExtractors (e.g. GraniteSpeech) that reject the sampling_rate parameter.
    # Get the model's expected sampling rate; nearly every feature extractor has this attribute.
    # A few non-standard models may not, so fall back to 16000 (the most common ASR sampling rate).
target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
if sample_rate != target_sr:
resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
waveform = resampler(waveform)
audio_array = waveform.squeeze(0).numpy().astype(np.float32)
chunk_length_s = app.state.config.get("chunk_length_s", 30)
    chunk_samples = int(chunk_length_s * target_sr)  # cast to int so range() accepts a float chunk_length_s
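    # e.g. chunk_length_s=30 at target_sr=16000 -> 480000 samples per chunk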
infer_kwargs = _build_infer_kwargs(lang)
ts1 = time.time()
texts = []
for i in range(0, len(audio_array), chunk_samples):
chunk = audio_array[i : i + chunk_samples]
result = asr_pipeline(chunk, **infer_kwargs)
texts.append(result["text"])
ts2 = time.time()
generated_text = "".join(texts)
    # wav2vec2-family models use U+2581 (▁) as the word delimiter; replace it with a space
    generated_text = generated_text.replace("\u2581", " ").strip()
processing_time = ts2 - ts1
rtf = processing_time / duration
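    # RTF (real-time factor) < 1.0 means transcription ran faster than real time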
print("Text:", generated_text, flush=True)
print(f"Audio duration:\t{duration:.3f} s", flush=True)
print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
return generated_text
@app.get("/health")
def health():
if status == "Running":
return {"status": "loading model"}
return {"status": "ok" if status == "Success" else "failed"}
@app.post("/transduce")
def transduce(
audio: UploadFile = File(...),
lang: str = Form("zh"),
background_tasks: BackgroundTasks = None,
):
try:
file_path = f"./input/{uuid.uuid4()}.wav"
with open(file_path, "wb") as f:
f.write(audio.file.read())
background_tasks.add_task(os.remove, file_path)
generated_text = run_asr(file_path, lang)
return {"generated_text": generated_text}
except Exception:
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
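# Example request (assuming port 8000 as in the commented-out entrypoint below):
#   curl -F "audio=@sample.wav" -F "lang=zh" http://localhost:8000/transduce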
# if __name__ == "__main__":
# uvicorn.run("fastapi_transformers:app", host="0.0.0.0", port=8000, workers=1)