import os
import sys
import time
import uuid
import datetime
import tempfile
import traceback

import soundfile as sf
import sherpa_onnx
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form

os.makedirs("./input", exist_ok=True)

# Module-level service state, populated by load_model() at startup.
status = "Running"   # "Running" -> still loading, "Success" -> ready, anything else -> failed
recognizer = None    # sherpa_onnx Offline/OnlineRecognizer instance
device = ""          # informational device string ("cpu", "cuda:0", "mlu:0", "npu:0")
model_type = ""      # resolved model family, see get_asr_model_type()

app = FastAPI()

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")

# Model families handled for OfflineRecognizer:
#   moonshine, fire_red, dolphin_ctc, paraformer, telespeech_ctc, whisper,
#   sensevoice, zipformer_ctc, transducer, nemo_ctc, nemo_canary, wenet_ctc
# OnlineRecognizer only supports:
#   zipformer_ctc, transducer, paraformer, nemo_ctc, wenet_ctc


def get_asr_model_type(model_name):
    """Guess the sherpa-onnx model family from a model (directory) name.

    Language notes: nemo_ctc, nemo_canary and moonshine currently have no
    Chinese models in sherpa-onnx (English ASR only); all nemo variants
    (nemo_ctc, nemo_canary and nemo transducers) lack Chinese models.

    Special matching rules (order matters):
      * "zipformer" counts as zipformer_ctc only when "ctc" is also present,
        otherwise it is a transducer;
      * "nemo" is its own family only with "ctc" or "canary", otherwise a
        transducer;
      * "conformer"/"lstm" are transducers, but must be tested after nemo;
      * "wenet" is checked last because "wenetspeech" is also a dataset name
        that appears in model names of every family.

    :param model_name: model name or directory basename (case-insensitive)
    :return: family string, or "unknown" when nothing matches
    """
    name = model_name.lower()
    if "tdnn" in name:
        # The tdnn family is of little practical use: the only published model
        # recognizes just the Hebrew words yes/no.
        return "tdnn"
    if "moonshine" in name:
        return "moonshine"
    if "fire-red" in name:
        return "fire_red"
    if "dolphin" in name:
        return "dolphin_ctc"
    if "paraformer" in name:
        return "paraformer"
    if "telespeech" in name:
        return "telespeech_ctc"
    if "whisper" in name:
        return "whisper"
    if "sense-voice" in name:
        return "sensevoice"
    if "zipformer" in name:
        return "zipformer_ctc" if "ctc" in name else "transducer"
    if "nemo" in name:
        if "ctc" in name:
            return "nemo_ctc"
        if "canary" in name:
            return "nemo_canary"
        return "transducer"
    if "conformer" in name or "lstm" in name:
        return "transducer"
    if "wenet" in name:
        return "wenet_ctc"
    return "unknown"


def _largest(model_dir, candidates):
    """Return the candidate file with the largest size on disk.

    A model directory may ship several variants of the same component
    (e.g. int8-quantized and fp32); the largest one is picked.
    """
    return max(candidates, key=lambda f: os.path.getsize(model_dir + "/" + f))


def _single_model_files(model_dir):
    """Scan model_dir for a single-file model.

    :return: (largest .onnx filename, tokens .txt filename or "")
    :raises RuntimeError: when no .onnx file is present
    """
    models, tokens = [], ""
    for f in os.listdir(model_dir):
        if f.endswith(".onnx"):
            models.append(f)
        elif f.endswith(".txt") and "token" in f:
            tokens = f
    if not models:
        raise RuntimeError(f"no .onnx model file found in {model_dir}")
    return _largest(model_dir, models), tokens


def _component_files(model_dir, keywords):
    """Scan model_dir for a multi-component model (encoder/decoder/...).

    Each .onnx file is assigned to the FIRST keyword it contains, so more
    specific keywords (e.g. "uncached_decode") must precede less specific
    ones (e.g. "cached_decode").

    :return: ({keyword: largest matching .onnx filename}, tokens filename)
    :raises RuntimeError: when some component has no candidate file
    """
    buckets = {k: [] for k in keywords}
    tokens = ""
    for f in os.listdir(model_dir):
        if f.endswith(".onnx"):
            for k in keywords:
                if k in f:
                    buckets[k].append(f)
                    break
        elif f.endswith(".txt") and "token" in f:
            tokens = f
    picked = {}
    for k, cand in buckets.items():
        if not cand:
            raise RuntimeError(f"no '{k}' .onnx file found in {model_dir}")
        picked[k] = _largest(model_dir, cand)
    return picked, tokens


@app.on_event("startup")
def load_model():
    """Build the global recognizer from app.state.config at startup.

    Config keys: use_gpu (bool), model_dir (str), model_type (str|None),
    model_name (str|None), warmup (bool), offline_model (bool),
    num_threads (int).  Sets the module globals status/recognizer/device/
    model_type; raises RuntimeError when the model cannot be constructed.
    """
    global status, recognizer, device, model_type

    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    _model_type = config.get("model_type", None)
    _model_name = config.get("model_name", None)
    warmup = config.get("warmup", False)
    isOffline = config.get("offline_model", True)
    num_threads = config.get("num_threads", 2)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # sherpa-onnx model layouts are all over the place.  Users who know the
    # family can pass model_type directly; otherwise they can pass the full
    # model name (the mounted path inside the image does not necessarily
    # contain it).  As a last resort, guess from the directory basename.
    if _model_type:
        model_type = _model_type
    elif _model_name:
        model_type = get_asr_model_type(_model_name)
    else:
        print("model_name and model_type both not provided, start guessing using model_dir", flush=True)
        model_type = get_asr_model_type(os.path.basename(model_dir))

    # Always-bound name hint used to tell nemo transducers apart below.
    # (The original code referenced a variable that was only bound on the
    # guessing path and raised NameError otherwise.)
    name_hint = (_model_name or os.path.basename(model_dir)).lower()

    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)
    print(" isOffline =", isOffline, flush=True)
    print(" num_threads =", num_threads, flush=True)

    try:
        recognizer = None
        # NOTE(review): CUSTOM_DEVICE mlu/ascend still maps to provider="cuda"
        # here, matching the original behavior — confirm this is intended.
        provider = "cuda" if use_gpu else "cpu"
        common = dict(debug=False, provider=provider, num_threads=num_threads)

        def path(f):
            # Join with "/" to match the original path construction.
            return model_dir + "/" + f

        if model_type == "whisper":
            parts, tokens = _component_files(model_dir, ["encode", "decode"])
            recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
                encoder=path(parts["encode"]), decoder=path(parts["decode"]),
                tokens=path(tokens), language="zh", **common)
        elif model_type == "sensevoice":
            model, tokens = _single_model_files(model_dir)
            recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
                model=path(model), tokens=path(tokens),
                use_itn=True, language="zh", **common)
        elif model_type == "paraformer":
            model, tokens = _single_model_files(model_dir)
            factory = (sherpa_onnx.OfflineRecognizer if isOffline
                       else sherpa_onnx.OnlineRecognizer).from_paraformer
            recognizer = factory(paraformer=path(model), tokens=path(tokens), **common)
        elif model_type == "zipformer_ctc":
            model, tokens = _single_model_files(model_dir)
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_zipformer_ctc(
                    model=path(model), tokens=path(tokens), **common)
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
                    model=path(model), tokens=path(tokens), **common)
        elif model_type == "telespeech_ctc":
            model, tokens = _single_model_files(model_dir)
            recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
                model=path(model), tokens=path(tokens), **common)
        elif model_type == "fire_red":
            parts, tokens = _component_files(model_dir, ["encode", "decode"])
            recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
                encoder=path(parts["encode"]), decoder=path(parts["decode"]),
                tokens=path(tokens), **common)
        elif model_type == "wenet_ctc":
            model, tokens = _single_model_files(model_dir)
            factory = (sherpa_onnx.OfflineRecognizer if isOffline
                       else sherpa_onnx.OnlineRecognizer).from_wenet_ctc
            recognizer = factory(model=path(model), tokens=path(tokens), **common)
        elif model_type == "dolphin_ctc":
            model, tokens = _single_model_files(model_dir)
            recognizer = sherpa_onnx.OfflineRecognizer.from_dolphin_ctc(
                model=path(model), tokens=path(tokens), **common)
        elif model_type == "transducer":
            parts, tokens = _component_files(model_dir, ["encode", "decode", "joiner"])
            # zipformer/conformer transducers are icefall exports and use the
            # default model_type; nemo transducers must be flagged explicitly.
            transducer_type = "nemo_transducer" if "nemo" in name_hint else "transducer"
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
                    encoder=path(parts["encode"]), decoder=path(parts["decode"]),
                    joiner=path(parts["joiner"]), tokens=path(tokens),
                    model_type=transducer_type, **common)
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
                    encoder=path(parts["encode"]), decoder=path(parts["decode"]),
                    joiner=path(parts["joiner"]), tokens=path(tokens), **common)
        elif model_type == "nemo_ctc":
            model, tokens = _single_model_files(model_dir)
            factory = (sherpa_onnx.OfflineRecognizer if isOffline
                       else sherpa_onnx.OnlineRecognizer).from_nemo_ctc
            recognizer = factory(model=path(model), tokens=path(tokens), **common)
        elif model_type == "nemo_canary":
            parts, tokens = _component_files(model_dir, ["encode", "decode"])
            recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
                encoder=path(parts["encode"]), decoder=path(parts["decode"]),
                tokens=path(tokens), **common)
        elif model_type == "moonshine":
            # "uncached_decode" must be matched before "cached_decode"
            # (the latter is a substring of the former).
            parts, tokens = _component_files(
                model_dir, ["preprocess", "encode", "uncached_decode", "cached_decode"])
            recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
                preprocessor=path(parts["preprocess"]), encoder=path(parts["encode"]),
                cached_decoder=path(parts["cached_decode"]),
                uncached_decoder=path(parts["uncached_decode"]),
                tokens=path(tokens), **common)
        elif model_type in ("tdnn", "tdnn_ctc"):
            # get_asr_model_type() returns "tdnn"; accept both spellings so
            # the guessed type actually reaches this branch.
            model, tokens = _single_model_files(model_dir)
            recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
                model=path(model), tokens=path(tokens), **common)
        else:
            raise RuntimeError("Cannot recognize model_type")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize model on {device}: {e}")

    if warmup:
        print("Start warmup...", flush=True)
        stream = recognizer.create_stream()
        audio, sample_rate = sf.read("warmup.wav", dtype="float32", always_2d=True)
        audio = audio[:, 0]  # mono 1-D samples, consistent with test_sherpa()
        stream.accept_waveform(sample_rate, audio)
        recognizer.decode_stream(stream)
        print("warmup complete.", flush=True)
    status = "Success"


def test_sherpa(wavefile):
    """Transcribe wavefile with the global recognizer and return the text.

    Also prints the transcript, elapsed time and real-time factor (RTF).
    Offline models get the audio in <=29 s slices; streaming models are fed
    2 s chunks.
    """
    isOffline = app.state.config.get("offline_model", True)
    audio, sample_rate = sf.read(wavefile, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # first channel only (mono)
    generated_text = ""
    start_t = datetime.datetime.now()

    if isOffline:
        # Non-streaming (OfflineRecognizer) inference.
        if model_type in ["sensevoice"]:
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, audio)
            recognizer.decode_stream(stream)
            generated_text = stream.result.text
        else:
            # Most offline models handle long audio poorly: training clips
            # are short and the exported ONNX graphs may cap intermediate
            # dimensions (the original CPU implementation can even crash),
            # so decode in short slices.
            slice_len = int(sample_rate * 29)
            for start in range(0, len(audio), slice_len):
                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, audio[start:start + slice_len])
                recognizer.decode_stream(stream)
                generated_text += stream.result.text
    else:
        # Streaming (OnlineRecognizer) inference: feed 2 s chunks.
        stream = recognizer.create_stream()
        chunk_len = int(sample_rate * 2)
        for start in range(0, len(audio), chunk_len):
            stream.accept_waveform(sample_rate, audio[start:start + chunk_len])
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
        # Signal end of input so the tail frames are flushed before the
        # final drain; without this the last partial chunk may be dropped.
        stream.input_finished()
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        generated_text = recognizer.get_result(stream)

    end_t = datetime.datetime.now()
    elapsed_seconds = (end_t - start_t).total_seconds()
    duration = audio.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{elapsed_seconds:.3f} s", flush=True)
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text


@app.get("/health")
def health():
    """Readiness probe: 'loading model' while startup runs, then ok/failed."""
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}


@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),  # accepted for API compatibility; not used by the recognizer
    background_tasks: BackgroundTasks = None,
):
    """Accept an uploaded WAV file, transcribe it and return the text.

    The upload is spooled to ./input/<uuid>.wav; on success the file is
    removed by a background task after the response is sent, on failure it
    is removed immediately (FastAPI does not run background tasks for error
    responses).
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_sherpa(file_path)
        background_tasks.add_task(os.remove, file_path)
        return {"generated_text": generated_text}
    except Exception:
        # Clean up the temp file ourselves on the error path.
        if os.path.exists(file_path):
            os.remove(file_path)
        raise HTTPException(status_code=500,
                            detail=f"Processing failed: \n{traceback.format_exc()}")

# if __name__ == "__main__":
#     uvicorn.run("fastapi_sherpa:app", host="0.0.0.0", port=1111, workers=1)