Files
enginex-bi_150-asr/sherpa-onnx/fastapi_sherpa.py
2026-04-08 06:41:00 +00:00

520 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import time
import uuid
import datetime
import tempfile
import soundfile as sf
import sherpa_onnx
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
os.makedirs("./input", exist_ok=True)  # scratch dir for uploaded audio files
# Global service state, populated by the startup hook (load_model) below.
status = "Running"  # "Running" while the model loads, "Success" once ready
recognizer = None  # sherpa_onnx Offline/OnlineRecognizer instance
device = ""  # resolved inference device label, e.g. "cpu" / "cuda:0" / "mlu:0"
model_type = ""  # resolved model family, see get_asr_model_type()
app = FastAPI()
# Optional accelerator override from the environment: values starting with
# "mlu" select Cambricon MLU, "ascend" selects Huawei NPU; anything else
# (including unset) falls back to CUDA when use_gpu is on.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
# Infer the model type from its name.  The taxonomy is messy; the custom
# types handled here (for OfflineRecognizer) are:
# moonshine
# fire_red
# dolphin_ctc
# paraformer
# telespeech_ctc
# whisper
# sensevoice
# zipformer_ctc
# transducer
# nemo_ctc
# nemo_canary
# wenet_ctc
# For OnlineRecognizer only five families are supported:
# zipformer_ctc, transducer, paraformer, nemo_ctc, wenet_ctc
def get_asr_model_type(model_name):
    """Infer the sherpa-onnx model family from a model (or directory) name.

    Language note: nemo_ctc, nemo_canary and moonshine currently have no
    Chinese models in sherpa-onnx and run English ASR; the remaining families
    run Chinese ASR.  No NeMo model (nemo_ctc, nemo_canary, or NeMo
    transducers) has a Chinese variant, and not every family has an English
    model either.

    Matching rules (order matters):
      - zipformer counts as zipformer_ctc only with "ctc" in the name,
        otherwise it is a transducer;
      - nemo gets its own class only with "ctc" or "canary", otherwise
        it is a transducer;
      - conformer/lstm are transducers, but must be tested after nemo;
      - "wenet" is tested last because "wenetspeech" is a *dataset* name
        that appears in model names of many families.

    Args:
        model_name: model name or directory basename; matched case-insensitively.

    Returns:
        One of "tdnn_ctc", "moonshine", "fire_red", "dolphin_ctc",
        "paraformer", "telespeech_ctc", "whisper", "sensevoice",
        "zipformer_ctc", "transducer", "nemo_ctc", "nemo_canary",
        "wenet_ctc", or "unknown".
    """
    name = model_name.lower()
    if "tdnn" in name:
        # Only one tdnn model exists (it recognizes Hebrew yes/no).
        # BUGFIX: previously returned "tdnn", which the loader rejected with
        # "Cannot recognize model_type"; return "tdnn_ctc" to match the
        # loader's branch name.
        return "tdnn_ctc"
    if "moonshine" in name:
        return "moonshine"
    if "fire-red" in name:
        return "fire_red"
    if "dolphin" in name:
        return "dolphin_ctc"
    if "paraformer" in name:
        return "paraformer"
    if "telespeech" in name:
        return "telespeech_ctc"
    if "whisper" in name:
        return "whisper"
    if "sense-voice" in name:
        return "sensevoice"
    if "zipformer" in name:
        return "zipformer_ctc" if "ctc" in name else "transducer"
    if "nemo" in name:
        if "ctc" in name:
            return "nemo_ctc"
        if "canary" in name:
            return "nemo_canary"
        return "transducer"
    if "conformer" in name or "lstm" in name:
        return "transducer"
    if "wenet" in name:
        return "wenet_ctc"
    return "unknown"
def _largest(model_dir, candidates):
    # A model dir may ship several variants of the same role (e.g. an
    # int8-quantized export next to the fp32 one); always pick the largest.
    return max(candidates, key=lambda f: os.path.getsize(os.path.join(model_dir, f)))


def _find_tokens(file_list):
    # Token vocabularies are *.txt files whose name contains "token".
    # Returns "" when none is present so the recognizer constructor fails
    # with a clear error instead of this module dying with a NameError
    # (previously `tokens` was left undefined in several branches).
    tokens = ""
    for f in file_list:
        if f.endswith(".txt") and "token" in f:
            tokens = f
    return tokens


def _pick_single_model(model_dir, file_list):
    # For families shipped as a single .onnx file: returns (model, tokens).
    models = [f for f in file_list if f.endswith(".onnx")]
    return _largest(model_dir, models), _find_tokens(file_list)


def _pick_encoder_decoder(model_dir, file_list):
    # For encoder/decoder families: returns (encoder, decoder, tokens).
    # The "encode not in f" guard mirrors the original elif chain so a file
    # matching both substrings is classified as an encoder.
    encoders = [f for f in file_list if "encode" in f and f.endswith(".onnx")]
    decoders = [f for f in file_list
                if "decode" in f and "encode" not in f and f.endswith(".onnx")]
    return (_largest(model_dir, encoders),
            _largest(model_dir, decoders),
            _find_tokens(file_list))


@app.on_event("startup")
def load_model():
    """Build the sherpa-onnx recognizer once at service startup.

    Configuration is read from ``app.state.config`` (set by the launcher):
    use_gpu, model_dir, model_type, model_name, warmup, offline_model,
    num_threads.  On success the module globals ``recognizer``, ``device``,
    ``model_type`` and ``status`` are populated; any failure is re-raised as
    RuntimeError so the service refuses to start with a broken model.
    """
    global status, recognizer, device, model_type
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    _model_type = config.get("model_type", None)
    _model_name = config.get("model_name", None)
    warmup = config.get("warmup", False)
    isOffline = config.get("offline_model", True)
    num_threads = config.get("num_threads", 2)

    # Resolve the device label (for reporting); CUSTOM_DEVICE selects
    # non-NVIDIA accelerators (Cambricon MLU / Huawei Ascend NPU).
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # sherpa-onnx has many model families.  Users who know theirs can pass
    # model_type directly, or a full model_name to guess from; as a last
    # resort guess from the mount directory name (which may not contain the
    # model name at all when a model is mounted into the image).
    # BUGFIX: model_name was previously assigned only in the fallback branch,
    # so the transducer branch below raised NameError whenever model_type or
    # model_name was configured explicitly.
    model_name = _model_name if _model_name else os.path.basename(model_dir)
    if _model_type:
        model_type = _model_type
    elif _model_name:
        model_type = get_asr_model_type(_model_name)
    else:
        print("model_name and model_type both not provided, start guessing using model_dir", flush=True)
        model_type = get_asr_model_type(model_name)

    print(">> Startup config:", flush=True)
    print("  model_dir =", model_dir, flush=True)
    print("  model_type =", model_type, flush=True)
    print("  use_gpu =", use_gpu, flush=True)
    print("  warmup =", warmup, flush=True)
    print("  isOffline =", isOffline, flush=True)
    print("  num_threads =", num_threads, flush=True)

    try:
        recognizer = None
        provider = "cuda" if use_gpu else "cpu"
        file_list = os.listdir(model_dir)
        if model_type == "whisper":
            encoder, decoder, tokens = _pick_encoder_decoder(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
                encoder=os.path.join(model_dir, encoder),
                decoder=os.path.join(model_dir, decoder),
                tokens=os.path.join(model_dir, tokens),
                language="zh",
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "sensevoice":
            model, tokens = _pick_single_model(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
                model=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                use_itn=True,
                language="zh",
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "paraformer":
            model, tokens = _pick_single_model(model_dir, file_list)
            # NOTE(review): the streaming from_paraformer may expect separate
            # encoder/decoder files rather than a single `paraformer=` model;
            # confirm against the sherpa-onnx OnlineRecognizer API before
            # relying on the streaming path.
            factory = sherpa_onnx.OfflineRecognizer if isOffline else sherpa_onnx.OnlineRecognizer
            recognizer = factory.from_paraformer(
                paraformer=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "zipformer_ctc":
            model, tokens = _pick_single_model(model_dir, file_list)
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_zipformer_ctc(
                    model=os.path.join(model_dir, model),
                    tokens=os.path.join(model_dir, tokens),
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
            else:
                # Streaming zipformer CTC uses a differently named factory.
                recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
                    model=os.path.join(model_dir, model),
                    tokens=os.path.join(model_dir, tokens),
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
        elif model_type == "telespeech_ctc":
            model, tokens = _pick_single_model(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
                model=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "fire_red":
            encoder, decoder, tokens = _pick_encoder_decoder(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
                encoder=os.path.join(model_dir, encoder),
                decoder=os.path.join(model_dir, decoder),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "wenet_ctc":
            model, tokens = _pick_single_model(model_dir, file_list)
            factory = sherpa_onnx.OfflineRecognizer if isOffline else sherpa_onnx.OnlineRecognizer
            recognizer = factory.from_wenet_ctc(
                model=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "dolphin_ctc":
            model, tokens = _pick_single_model(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_dolphin_ctc(
                model=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "transducer":
            # Transducers have three parts: encoder, decoder and joiner.
            encoder_list, decoder_list, joiner_list = [], [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "joiner" in file and file.endswith(".onnx"):
                    joiner_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = _largest(model_dir, encoder_list)
            decoder = _largest(model_dir, decoder_list)
            joiner = _largest(model_dir, joiner_list)
            # zipformer/conformer transducers exported by icefall use the
            # default model_type; NeMo transducers need an explicit marker.
            transducer_type = "nemo_transducer" if "nemo" in model_name.lower() else "transducer"
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
                    encoder=os.path.join(model_dir, encoder),
                    decoder=os.path.join(model_dir, decoder),
                    joiner=os.path.join(model_dir, joiner),
                    tokens=os.path.join(model_dir, tokens),
                    model_type=transducer_type,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
                    encoder=os.path.join(model_dir, encoder),
                    decoder=os.path.join(model_dir, decoder),
                    joiner=os.path.join(model_dir, joiner),
                    tokens=os.path.join(model_dir, tokens),
                    debug=False,
                    provider=provider,
                    num_threads=num_threads,
                )
        elif model_type == "nemo_ctc":
            model, tokens = _pick_single_model(model_dir, file_list)
            factory = sherpa_onnx.OfflineRecognizer if isOffline else sherpa_onnx.OnlineRecognizer
            recognizer = factory.from_nemo_ctc(
                model=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "nemo_canary":
            encoder, decoder, tokens = _pick_encoder_decoder(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
                encoder=os.path.join(model_dir, encoder),
                decoder=os.path.join(model_dir, decoder),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type == "moonshine":
            # Moonshine ships four graphs; the elif order is load-bearing:
            # "uncached_decode" contains the substring "cached_decode", so it
            # must be tested first.
            preprocessor_list, encoder_list = [], []
            cached_decoder_list, uncached_decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "preprocess" in file and file.endswith(".onnx"):
                    preprocessor_list.append(file)
                elif "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "uncached_decode" in file and file.endswith(".onnx"):
                    uncached_decoder_list.append(file)
                elif "cached_decode" in file and file.endswith(".onnx"):
                    cached_decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
                preprocessor=os.path.join(model_dir, _largest(model_dir, preprocessor_list)),
                encoder=os.path.join(model_dir, _largest(model_dir, encoder_list)),
                cached_decoder=os.path.join(model_dir, _largest(model_dir, cached_decoder_list)),
                uncached_decoder=os.path.join(model_dir, _largest(model_dir, uncached_decoder_list)),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        elif model_type in ("tdnn", "tdnn_ctc"):
            # Accept both spellings: get_asr_model_type historically emitted
            # "tdnn" while this loader expected "tdnn_ctc".
            model, tokens = _pick_single_model(model_dir, file_list)
            recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
                model=os.path.join(model_dir, model),
                tokens=os.path.join(model_dir, tokens),
                debug=False,
                provider=provider,
                num_threads=num_threads,
            )
        else:
            raise RuntimeError("Cannot recognize model_type")
    except Exception as e:
        # Previously said "Failed to initial cuda model" even on CPU/MLU/NPU.
        raise RuntimeError(f"Failed to initialize {device} model: {e}") from e

    if warmup:
        # One dummy decode so the first real request doesn't pay the
        # graph-initialization cost.
        print("Start warmup...", flush=True)
        stream = recognizer.create_stream()
        # NOTE(review): warmup feeds the 2-D (frames, channels) array while
        # test_sherpa feeds a mono channel — confirm accept_waveform accepts both.
        audio, sample_rate = sf.read("warmup.wav", dtype="float32", always_2d=True)
        stream.accept_waveform(sample_rate, audio)
        recognizer.decode_stream(stream)
        print("warmup complete.", flush=True)
    status = "Success"
def test_sherpa(wavefile):
    """Run ASR on *wavefile* and return the recognized text.

    Uses the module globals ``recognizer`` and ``model_type`` populated by
    load_model(), and picks offline vs. streaming decoding from
    app.state.config["offline_model"].  Also prints the decoded text,
    audio duration, elapsed time and real-time factor.
    """
    isOffline = app.state.config.get("offline_model", True)
    # always_2d=True so mono and stereo both come back as (frames, channels)
    audio, sample_rate = sf.read(wavefile, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # keep only the first channel
    generated_text = ""
    start_t = datetime.datetime.now()
    if isOffline:
        # OfflineRecognizer: non-streaming inference.
        if model_type in ["sensevoice"]:
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, audio)
            recognizer.decode_stream(stream)
            generated_text = stream.result.text
        else:
            # Most offline ASR models handle long audio poorly: training clips
            # are short, and the exported ONNX graph may cap some intermediate
            # dimensions (even the original CPU inference can crash), so decode
            # in ~29 s segments and concatenate the text.
            start_index = 0
            internal = int(sample_rate * 29)  # segment length in samples
            while start_index < len(audio):
                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, audio[start_index:start_index + internal])
                recognizer.decode_stream(stream)
                generated_text += stream.result.text
                start_index += internal
    else:
        # OnlineRecognizer: streaming inference, feeding 2 s chunks at a time.
        stream = recognizer.create_stream()
        start_index = 0
        chunk_size = int(sample_rate * 2)
        while start_index < len(audio):
            chunk = audio[start_index:start_index + chunk_size]
            stream.accept_waveform(sample_rate, chunk)
            # Drain everything that is decodable after this chunk.
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
            # mid_text = recognizer.get_result(stream)
            # print("partial result: " + mid_text, flush=True)
            start_index += chunk_size
        # Final drain after all audio has been fed.
        # NOTE(review): stream.input_finished() is never called before this
        # drain — the tail of the audio may be left undecoded; confirm against
        # the sherpa-onnx streaming examples.
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        generated_text = recognizer.get_result(stream)
    end_t = datetime.datetime.now()
    elapsed_seconds = (end_t - start_t).total_seconds()
    duration = audio.shape[-1] / sample_rate  # audio length in seconds
    rtf = elapsed_seconds / duration  # real-time factor (lower is faster)
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{elapsed_seconds:.3f} s", flush=True)
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Health probe: reports model loading progress.

    Returns {"status": "loading model"} while startup is still running,
    {"status": "ok"} once the model loaded, and {"status": "failed"} otherwise.
    """
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Transcribe an uploaded audio file and return the recognized text.

    Args:
        audio: uploaded audio payload, saved under ./input with a UUID name.
        lang: requested language (currently not forwarded to the recognizer).
        background_tasks: FastAPI-injected queue used for deferred cleanup.

    Raises:
        HTTPException: 500 with the traceback when decoding fails.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_sherpa(file_path)
        # Schedule deletion only after decoding succeeded; tasks attached
        # here run once the response has been sent.
        if background_tasks is not None:
            background_tasks.add_task(os.remove, file_path)
        return {"generated_text": generated_text}
    except Exception:
        # BUGFIX: background tasks attached to the request do not run when
        # the handler raises, so the temp file used to leak on every failed
        # request — remove it eagerly here instead.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
# if __name__ == "__main__":
# uvicorn.run("fastapi_sherpa:app", host="0.0.0.0", port=1111, workers=1)