import os
import time
import uuid
import json
import inspect
import traceback

import numpy as np
import torch
import torchaudio

from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uvicorn
from transformers import pipeline as hf_pipeline

os.makedirs("./input", exist_ok=True)
status = "Running"
asr_pipeline = None
is_whisper = False  # the only branch we must distinguish: Whisper (seq2seq) vs. every other CTC-style model
device = ""
app = FastAPI()

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu


class _SamplingRateCompatProxy:
    """Compatibility wrapper for non-standard FeatureExtractors.

    The preprocess step of a transformers pipeline always passes standard kwargs such as
    sampling_rate and return_tensors to the feature_extractor, but some models'
    FeatureExtractors (e.g. GraniteSpeech) do not implement these parameters. This proxy
    inspects the signature once at init time and, when called, forwards only the parameters
    the FeatureExtractor actually accepts. The audio must already be resampled to the model's
    expected sampling rate before the call (run_asr takes care of this).
    """
    def __init__(self, fe):
        object.__setattr__(self, "_fe", fe)
        # Inspect the signature once at init time to determine which parameters are accepted
        try:
            sig = inspect.signature(fe.__call__)
            has_var_kw = any(
                p.kind == inspect.Parameter.VAR_KEYWORD
                for p in sig.parameters.values()
            )
            accepted = None if has_var_kw else set(sig.parameters.keys()) - {"self"}
        except Exception:
            accepted = None  # if the signature cannot be inspected, do not filter
        object.__setattr__(self, "_accepted", accepted)

    def __call__(self, *args, **kwargs):
        accepted = object.__getattribute__(self, "_accepted")
        if accepted is not None:
            kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return object.__getattribute__(self, "_fe")(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(object.__getattribute__(self, "_fe"), name)

    def __setattr__(self, name, value):
        setattr(object.__getattribute__(self, "_fe"), name, value)


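# A minimal usage sketch (not executed by the service): _EchoFE below is a hypothetical
# stand-in for a FeatureExtractor whose __call__ accepts only `audio` and no **kwargs,
# so the proxy silently drops the sampling_rate kwarg that the pipeline always passes.
#
#     class _EchoFE:
#         def __call__(self, audio):
#             return {"input_features": audio}
#
#     fe = _SamplingRateCompatProxy(_EchoFE())
#     fe(np.zeros(16000, dtype=np.float32), sampling_rate=16000)  # sampling_rate filtered out, no TypeError

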
def _check_is_whisper(model_dir: str, model_type_override: str = None) -> bool:
    """Determine whether the model uses the Whisper architecture.

    An explicit model_type_override from the user takes priority; otherwise read the
    model_type field from config.json (every Whisper fine-tuned model carries this field).
    """
    if model_type_override:
        return model_type_override.lower() == "whisper"
    config_path = os.path.join(model_dir, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            cfg = json.load(f)
        return cfg.get("model_type", "").lower() == "whisper"
    # If config.json is missing, fall back to the directory name as a last resort
    return "whisper" in os.path.basename(model_dir).lower()


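# For reference, the field this check reads looks like {"model_type": "whisper", ...} in a
# Whisper checkpoint's config.json, versus e.g. {"model_type": "wav2vec2", ...} for CTC models.

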
@app.on_event("startup")
|
||
def load_model():
|
||
global status, asr_pipeline, is_whisper, device
|
||
|
||
config = app.state.config
|
||
use_gpu = config.get("use_gpu", True)
|
||
model_dir = config.get("model_dir", "/model")
|
||
model_type_override = config.get("model_type", None) # 可选,仅用于覆盖自动判断
|
||
warmup = config.get("warmup", False)
|
||
use_fp16 = config.get("fp16", False) # 默认 fp32,需要用户显式开启
|
||
|
||
# 与 fastapi_funasr.py 保持一致的设备字符串逻辑,直接传字符串给 pipeline
|
||
device = "cpu"
|
||
if use_gpu:
|
||
if CUSTOM_DEVICE.startswith("mlu"):
|
||
device = "mlu:0"
|
||
elif CUSTOM_DEVICE.startswith("ascend"):
|
||
device = "npu:0"
|
||
else:
|
||
device = "cuda:0"
|
||
|
||
is_whisper = _check_is_whisper(model_dir, model_type_override)
|
||
# 默认 fp32,跨平台兼容性最好且不影响精度对比
|
||
# fp16 需要用户显式开启(--fp16),且应确认当前硬件支持
|
||
torch_dtype = torch.float16 if use_fp16 else torch.float32
|
||
|
||
print(">> Startup config:")
|
||
print(" model_dir =", model_dir, flush=True)
|
||
print(" is_whisper =", is_whisper, flush=True)
|
||
print(" device =", device, flush=True)
|
||
print(" torch_dtype =", torch_dtype, flush=True)
|
||
print(" chunk_length_s =", app.state.config.get("chunk_length_s", 30), flush=True)
|
||
print(" warmup =", warmup, flush=True)
|
||
|
||
# transformers pipeline 直接接受设备字符串("cpu"/"cuda:0"/"mlu:0"/"npu:0")
|
||
# 会自动读取 config.json 实例化正确的模型类,无需手动指定架构
|
||
# 注意:不在 pipeline 构建时传 chunk_length_s,由 run_asr 自行分片后逐段调用
|
||
# 原因:部分模型(如 GraniteSpeech)的 FeatureExtractor 不接受 sampling_rate 参数,
|
||
# 而 pipeline 内部的 chunk_iter 固定会传该参数,导致报错
|
||
asr_pipeline = hf_pipeline(
|
||
task="automatic-speech-recognition",
|
||
model=model_dir,
|
||
device=device,
|
||
torch_dtype=torch_dtype,
|
||
)
|
||
|
||
# 检查 feature extractor 是否接受 sampling_rate 参数
|
||
# pipeline 的 preprocess 固定会传此参数(硬编码行为),不接受的模型需要代理包装
|
||
try:
|
||
sig = inspect.signature(asr_pipeline.feature_extractor.__call__)
|
||
accepts_sr = "sampling_rate" in sig.parameters or any(
|
||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||
)
|
||
except Exception:
|
||
accepts_sr = True # 无法检测时保守假设接受
|
||
if not accepts_sr:
|
||
asr_pipeline.feature_extractor = _SamplingRateCompatProxy(asr_pipeline.feature_extractor)
|
||
print(" Note: FeatureExtractor does not accept sampling_rate, applied compat proxy", flush=True)
|
||
|
||
if warmup:
|
||
print("Start warmup...", flush=True)
|
||
# 获取模型期望的采样率,绝大多数模型的 feature extractor 都有此属性
|
||
# 极少数非标准模型可能没有,兜底用 16000(ASR 领域最通用的标准采样率)
|
||
target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
|
||
dummy = np.zeros(target_sr, dtype=np.float32) # 1 秒静音
|
||
asr_pipeline(dummy, **_build_infer_kwargs("zh"))
|
||
print("warmup complete.", flush=True)
|
||
|
||
status = "Success"
|
||
|
||
|
||
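# A hypothetical example of the config dict the startup hook expects (the keys mirror the
# config.get(...) calls above; nothing in this file sets app.state.config itself):
#
#     app.state.config = {"model_dir": "/model", "use_gpu": True, "fp16": False,
#                         "warmup": True, "chunk_length_s": 30}

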
def _build_infer_kwargs(lang: str) -> dict:
    """Whisper needs extra language kwargs at inference time; CTC-style models need none.

    return_timestamps is no longer passed: we slice the audio ourselves and call the pipeline
    per chunk, so the pipeline's internal stitching is not needed.
    """
    if is_whisper:
        return {"generate_kwargs": {"language": lang, "task": "transcribe"}}
    return {}


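# A quick sketch of the two cases: for a Whisper model and lang="zh" this returns
# {"generate_kwargs": {"language": "zh", "task": "transcribe"}}, which the pipeline forwards
# to model.generate(); for CTC models it returns {} and the pipeline runs a plain forward pass.

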
def run_asr(audio_file: str, lang: str) -> str:
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate

    # Downmix multi-channel audio to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample up front to the model's expected sampling rate, then pass a plain numpy array
    # (not a dict) to the pipeline. This skips the pipeline's internal sampling_rate plumbing
    # and works around models (e.g. GraniteSpeech) whose FeatureExtractor rejects the
    # sampling_rate argument.
    # Fetch the model's expected sampling rate; almost every feature extractor has this attribute.
    # The rare non-standard model may not, so fall back to 16000 (the de facto ASR standard).
    target_sr = getattr(asr_pipeline.feature_extractor, "sampling_rate", 16000)
    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)

    audio_array = waveform.squeeze(0).numpy().astype(np.float32)

    chunk_length_s = app.state.config.get("chunk_length_s", 30)
    chunk_samples = int(chunk_length_s * target_sr)  # int() guards against a float config value breaking the slice
    infer_kwargs = _build_infer_kwargs(lang)

    ts1 = time.time()
    texts = []
    for i in range(0, len(audio_array), chunk_samples):
        chunk = audio_array[i : i + chunk_samples]
        result = asr_pipeline(chunk, **infer_kwargs)
        texts.append(result["text"])
    ts2 = time.time()

    generated_text = "".join(texts)
    # wav2vec2-family models emit U+2581 (▁) as the word delimiter; replace it with spaces
    generated_text = generated_text.replace("\u2581", " ").strip()

    processing_time = ts2 - ts1
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)

    return generated_text


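# Worked example of the real-time factor printed above: transcribing 10.000 s of audio in
# 2.500 s gives RTF = 2.500/10.000 = 0.250; RTF < 1 means faster than real time.

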
@app.get("/health")
|
||
def health():
|
||
if status == "Running":
|
||
return {"status": "loading model"}
|
||
return {"status": "ok" if status == "Success" else "failed"}
|
||
|
||
|
||
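# Usage sketch (host/port assume the commented-out launcher at the bottom of this file):
#   curl http://localhost:8000/health
# Returns {"status": "loading model"} while the startup hook runs, then {"status": "ok"}.

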
@app.post("/transduce")
|
||
def transduce(
|
||
audio: UploadFile = File(...),
|
||
lang: str = Form("zh"),
|
||
background_tasks: BackgroundTasks = None,
|
||
):
|
||
try:
|
||
file_path = f"./input/{uuid.uuid4()}.wav"
|
||
with open(file_path, "wb") as f:
|
||
f.write(audio.file.read())
|
||
background_tasks.add_task(os.remove, file_path)
|
||
generated_text = run_asr(file_path, lang)
|
||
return {"generated_text": generated_text}
|
||
except Exception:
|
||
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
|
||
|
||
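# Usage sketch (sample.wav is a hypothetical local file; host/port as above):
#   curl -F "audio=@sample.wav" -F "lang=zh" http://localhost:8000/transduce
# Returns {"generated_text": "..."} on success.
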
# if __name__ == "__main__":
#     uvicorn.run("fastapi_transformers:app", host="0.0.0.0", port=8000, workers=1)
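# Since load_model reads app.state.config but nothing in this file sets it, an external launcher
# is expected to populate the config before serving. A minimal sketch of such a launcher, under
# the assumption that config keys map 1:1 to CLI flags (the flag names are illustrative, except
# --fp16, which the comments above mention):
#
#     import argparse
#     import uvicorn
#     from fastapi_transformers import app
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model_dir", default="/model")
#     parser.add_argument("--model_type", default=None)
#     parser.add_argument("--chunk_length_s", type=int, default=30)
#     parser.add_argument("--fp16", action="store_true")
#     parser.add_argument("--warmup", action="store_true")
#     args = parser.parse_args()
#
#     app.state.config = vars(args)
#     uvicorn.run(app, host="0.0.0.0", port=8000)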