Files
enginex-bi_150-asr/sherpa-onnx/fastapi_sherpa.py

520 lines
22 KiB
Python
Raw Permalink Normal View History

2026-04-08 06:41:00 +00:00
import os
import sys
import time
import uuid
import datetime
import tempfile
import soundfile as sf
import sherpa_onnx
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
os.makedirs("./input", exist_ok=True)  # scratch directory for uploaded wav files
# Module-level service state; populated by load_model() at startup and read
# by the /health and /transduce handlers.
status = "Running"  # "Running" while loading; "Success" once the recognizer is ready
recognizer = None  # sherpa_onnx Offline/OnlineRecognizer instance (set in load_model)
device = ""  # resolved compute device string, e.g. "cpu", "cuda:0", "mlu:0", "npu:0"
model_type = ""  # category string produced by get_asr_model_type() / config
app = FastAPI()
# Optional device-family override, e.g. starts with "mlu" (Cambricon) or
# "ascend" (Huawei NPU); anything else with use_gpu=True means CUDA.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
# Guess the model category from its name. sherpa-onnx has many flavors; the
# custom categories handled here for OfflineRecognizer are:
# moonshine
# fire_red
# dolphin_ctc
# paraformer
# telespeech_ctc
# whisper
# sensevoice
# zipformer_ctc
# transducer
# nemo_ctc
# nemo_canary
# wenet_ctc
# OnlineRecognizer supports only five of them: zipformer_ctc, transducer, paraformer, nemo_ctc, wenet_ctc
def get_asr_model_type(model_name):
    """Guess the recognizer category (and implied ASR language task) from a model name.

    nemo_ctc, nemo_canary and moonshine currently have no Chinese models in
    sherpa-onnx and run English ASR; the other categories run Chinese ASR
    (all NeMo models, including NeMo transducers, are non-Chinese, and not
    every category has English models either).

    Matching rules (order matters):
      * "zipformer" counts as zipformer_ctc only when "ctc" also appears,
        otherwise it is a transducer
      * "nemo" needs "ctc" or "canary" for its own category, otherwise it is
        a transducer
      * "conformer"/"lstm" are transducers, but must be checked after nemo
      * "wenet" is checked last because "wenetspeech" is also a dataset name
        that shows up in model names of every category

    Returns "unknown" when nothing matches.
    """
    name = model_name.lower()
    if "tdnn" in name:
        # TDNN demo model (only recognizes Hebrew yes/no).  Return "tdnn_ctc"
        # so it matches the loader branch in load_model(); the old "tdnn"
        # value was never dispatched and always raised "Cannot recognize".
        return "tdnn_ctc"
    if "moonshine" in name:
        return "moonshine"
    if "fire-red" in name:
        return "fire_red"
    if "dolphin" in name:
        return "dolphin_ctc"
    if "paraformer" in name:
        return "paraformer"
    if "telespeech" in name:
        return "telespeech_ctc"
    if "whisper" in name:
        return "whisper"
    if "sense-voice" in name:
        return "sensevoice"
    if "zipformer" in name:
        return "zipformer_ctc" if "ctc" in name else "transducer"
    if "nemo" in name:
        if "ctc" in name:
            return "nemo_ctc"
        if "canary" in name:
            return "nemo_canary"
        return "transducer"
    if "conformer" in name or "lstm" in name:
        return "transducer"
    if "wenet" in name:
        return "wenet_ctc"
    return "unknown"
@app.on_event("startup")
def load_model():
    """Build the sherpa-onnx recognizer at service startup.

    Reads app.state.config (keys: "use_gpu", "model_dir", "model_type",
    "model_name", "warmup", "offline_model", "num_threads") and populates the
    module globals status / recognizer / device / model_type.

    Raises:
        RuntimeError: when the model type cannot be determined or the
            recognizer fails to initialize (original cause is chained).
    """
    global status, recognizer, device, model_type
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    _model_type = config.get("model_type", None)
    _model_name = config.get("model_name", None)
    warmup = config.get("warmup", False)
    isOffline = config.get("offline_model", True)
    num_threads = config.get("num_threads", 2)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # sherpa-onnx model flavors are numerous: a user who knows the flavor can
    # pass model_type directly, or pass the full model_name to classify from
    # (the mounted directory path does not necessarily contain the name);
    # otherwise we guess from the directory's basename.
    # model_name is ALWAYS bound here because the transducer branch below
    # needs it to detect NeMo transducers -- the original code crashed with
    # UnboundLocalError there whenever model_type/model_name was supplied.
    model_name = _model_name if _model_name else os.path.basename(model_dir)
    if _model_type:
        model_type = _model_type
    elif _model_name:
        model_type = get_asr_model_type(_model_name)
    else:
        print("model_name and model_type both not provided, start guessing using model_dir", flush=True)
        model_type = get_asr_model_type(model_name)

    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)
    print(" isOffline =", isOffline, flush=True)
    print(" num_threads =", num_threads, flush=True)

    def _path(fname):
        # Join relative to the model directory.
        return model_dir + "/" + fname

    def _largest(candidates, kind):
        # A directory may ship several variants of the same component
        # (e.g. int8-quantized and fp32); keep the largest file.
        if not candidates:
            raise RuntimeError(f"no {kind} .onnx file found in {model_dir}")
        return max(candidates, key=lambda f: os.path.getsize(_path(f)))

    def _scan(*keywords):
        # Bucket .onnx files by the FIRST matching keyword -- order matters
        # ("cached_decode" is a substring of "uncached_decode") -- and locate
        # the tokens .txt file.  With no keywords, every .onnx file lands in
        # a single bucket.
        keys = keywords if keywords else ("",)
        buckets = [[] for _ in keys]
        tokens = ""
        for fname in file_list:
            if fname.endswith(".onnx"):
                for i, key in enumerate(keys):
                    if key in fname:
                        buckets[i].append(fname)
                        break
            elif fname.endswith(".txt") and "token" in fname:
                tokens = fname
        if not tokens:
            # The original left `tokens` unbound here (UnboundLocalError);
            # fail with an actionable message instead.
            raise RuntimeError(f"no tokens .txt file found in {model_dir}")
        return buckets, tokens

    try:
        recognizer = None
        provider = "cuda" if use_gpu else "cpu"
        file_list = os.listdir(model_dir)
        if model_type == "whisper":
            (encoders, decoders), tokens = _scan("encode", "decode")
            recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
                encoder=_path(_largest(encoders, "encoder")),
                decoder=_path(_largest(decoders, "decoder")),
                tokens=_path(tokens),
                language="zh",
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "sensevoice":
            (models,), tokens = _scan()
            recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
                model=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                use_itn=True,
                language="zh",
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "paraformer":
            (models,), tokens = _scan()
            cls = sherpa_onnx.OfflineRecognizer if isOffline else sherpa_onnx.OnlineRecognizer
            recognizer = cls.from_paraformer(
                paraformer=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "zipformer_ctc":
            (models,), tokens = _scan()
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_zipformer_ctc(
                    model=_path(_largest(models, "model")),
                    tokens=_path(tokens),
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
            else:
                # NOTE: the streaming factory name differs ("zipformer2").
                recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
                    model=_path(_largest(models, "model")),
                    tokens=_path(tokens),
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
        elif model_type == "telespeech_ctc":
            (models,), tokens = _scan()
            recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
                model=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "fire_red":
            (encoders, decoders), tokens = _scan("encode", "decode")
            recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
                encoder=_path(_largest(encoders, "encoder")),
                decoder=_path(_largest(decoders, "decoder")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "wenet_ctc":
            (models,), tokens = _scan()
            cls = sherpa_onnx.OfflineRecognizer if isOffline else sherpa_onnx.OnlineRecognizer
            recognizer = cls.from_wenet_ctc(
                model=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "dolphin_ctc":
            (models,), tokens = _scan()
            recognizer = sherpa_onnx.OfflineRecognizer.from_dolphin_ctc(
                model=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "transducer":
            (encoders, decoders, joiners), tokens = _scan("encode", "decode", "joiner")
            encoder = _path(_largest(encoders, "encoder"))
            decoder = _path(_largest(decoders, "decoder"))
            joiner = _path(_largest(joiners, "joiner"))
            # zipformer/conformer transducers exported from icefall use the
            # default type; NeMo transducers must be flagged explicitly.
            transducer_type = "nemo_transducer" if "nemo" in model_name.lower() else "transducer"
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
                    encoder=encoder,
                    decoder=decoder,
                    joiner=joiner,
                    tokens=_path(tokens),
                    model_type=transducer_type,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
                    encoder=encoder,
                    decoder=decoder,
                    joiner=joiner,
                    tokens=_path(tokens),
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
        elif model_type == "nemo_ctc":
            (models,), tokens = _scan()
            cls = sherpa_onnx.OfflineRecognizer if isOffline else sherpa_onnx.OnlineRecognizer
            recognizer = cls.from_nemo_ctc(
                model=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "nemo_canary":
            (encoders, decoders), tokens = _scan("encode", "decode")
            recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
                encoder=_path(_largest(encoders, "encoder")),
                decoder=_path(_largest(decoders, "decoder")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "moonshine":
            (preprocessors, encoders, uncached, cached), tokens = _scan(
                "preprocess", "encode", "uncached_decode", "cached_decode")
            recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
                preprocessor=_path(_largest(preprocessors, "preprocessor")),
                encoder=_path(_largest(encoders, "encoder")),
                cached_decoder=_path(_largest(cached, "cached_decoder")),
                uncached_decoder=_path(_largest(uncached, "uncached_decoder")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type in ("tdnn", "tdnn_ctc"):
            # Accept both spellings: get_asr_model_type historically returned
            # "tdnn" while this branch only matched "tdnn_ctc".
            (models,), tokens = _scan()
            recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
                model=_path(_largest(models, "model")),
                tokens=_path(tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        else:
            raise RuntimeError("Cannot recognize model_type")
    except Exception as e:
        # Chain the original cause; the old message wrongly said "cuda" even
        # for CPU runs.
        raise RuntimeError(f"Failed to initialize model: {e}") from e

    if warmup:
        print("Start warmup...", flush=True)
        stream = recognizer.create_stream()
        audio, sample_rate = sf.read("warmup.wav", dtype="float32", always_2d=True)
        stream.accept_waveform(sample_rate, audio)
        recognizer.decode_stream(stream)
        print("warmup complete.", flush=True)
    status = "Success"
def test_sherpa(wavefile):
    """Transcribe a wav file with the module-global recognizer and return the text.

    Offline (non-streaming) models other than SenseVoice receive the audio in
    ~29 s slices; streaming models are fed 2 s chunks.  Also prints the
    transcript, elapsed time and real-time factor (RTF).
    """
    isOffline = app.state.config.get("offline_model", True)
    audio, sample_rate = sf.read(wavefile, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # mono: keep only the first channel
    generated_text = ""
    start_t = datetime.datetime.now()
    if isOffline:
        # OfflineRecognizer: non-streaming inference.
        if model_type in ["sensevoice"]:
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, audio)
            recognizer.decode_stream(stream)
            generated_text = stream.result.text
        else:
            # Most offline ASR models handle long audio poorly: training clips
            # are short and some exported ONNX graphs cap intermediate
            # dimensions (even CPU inference of the originals can crash), so
            # decode in slices just under 30 s and concatenate the results.
            slice_len = int(sample_rate * 29)
            start_index = 0
            while start_index < len(audio):
                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, audio[start_index:start_index + slice_len])
                recognizer.decode_stream(stream)
                generated_text += stream.result.text
                start_index += slice_len
    else:
        # OnlineRecognizer: streaming inference, feed 2 s of audio at a time.
        stream = recognizer.create_stream()
        chunk_size = int(sample_rate * 2)
        start_index = 0
        while start_index < len(audio):
            stream.accept_waveform(sample_rate, audio[start_index:start_index + chunk_size])
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
            start_index += chunk_size
        # Fix: signal end of input so the recognizer flushes its buffered tail
        # audio; without this the final frames were silently dropped.
        stream.input_finished()
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        generated_text = recognizer.get_result(stream)
    end_t = datetime.datetime.now()
    elapsed_seconds = (end_t - start_t).total_seconds()
    duration = audio.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{elapsed_seconds:.3f} s", flush=True)
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Liveness/readiness probe: reports whether the model has finished loading."""
    if status == "Running":
        # Startup hook has not completed yet.
        return {"status": "loading model"}
    # After startup: "Success" means the recognizer is ready.
    return {"status": "ok" if status == "Success" else "failed"}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Transcribe an uploaded wav file and return {"generated_text": ...}.

    The upload is spooled to ./input/<uuid>.wav, transcribed via
    test_sherpa(), and the temp file is removed afterwards.  `lang` is
    accepted for API compatibility but currently unused (language is fixed
    when the recognizer is created).
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_sherpa(file_path)
    except Exception:
        # Fix: the original scheduled deletion as a background task before
        # transcription, but background tasks never run when the handler
        # raises, leaking the temp file on every failed request.  Delete
        # eagerly here instead.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
    if background_tasks is not None:
        # Defer deletion until after the response has been sent.
        background_tasks.add_task(os.remove, file_path)
    else:
        # Fix: original crashed with AttributeError when no BackgroundTasks
        # object was injected.
        os.remove(file_path)
    return {"generated_text": generated_text}
# if __name__ == "__main__":
# uvicorn.run("fastapi_sherpa:app", host="0.0.0.0", port=1111, workers=1)