import os
import sys
import time
import uuid
import datetime
import tempfile
import soundfile as sf
import sherpa_onnx
import traceback

from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form

os.makedirs("./input", exist_ok=True)
status = "Running"
recognizer = None
device = ""
model_type = ""
app = FastAPI()

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
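
# NOTE: app.state.config is read in load_model() below but is never assigned in
# this file; the launcher is expected to set it before startup. A minimal sketch,
# with keys inferred from the config.get() calls in load_model():
#
#   app.state.config = {
#       "use_gpu": True,          # pick the CUDA/MLU/NPU provider vs. CPU
#       "model_dir": "/model",    # directory holding the .onnx and tokens files
#       "model_type": None,       # e.g. "sensevoice"; skips name-based guessing
#       "model_name": None,       # used by get_asr_model_type() when set
#       "warmup": False,          # decode warmup.wav once at startup
#       "offline_model": True,    # OfflineRecognizer vs. OnlineRecognizer
#       "num_threads": 2,
#   }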

# Infer the model type from the model name. This is messy; the custom types
# handled here (for OfflineRecognizer) are:
# moonshine
# fire_red
# dolphin_ctc
# paraformer
# telespeech_ctc
# whisper
# sensevoice
# zipformer_ctc
# transducer
# nemo_ctc
# nemo_canary
# wenet_ctc
# For OnlineRecognizer there are only five: zipformer_ctc, transducer, paraformer,
# nemo_ctc and wenet_ctc.
def get_asr_model_type(model_name):
    # Infer from the name both the model type and the language task to test.
    # nemo_ctc, nemo_canary and moonshine currently have no Chinese models in
    # sherpa-onnx, so they run an English ASR task; the other types run Chinese ASR.
    # No NeMo model (nemo_ctc, nemo_canary, or the NeMo models among the
    # transducers) has a Chinese variant.
    # Not every category has an English model either.

    # Special rules:
    # - A zipformer counts as zipformer_ctc only when "ctc" is in the name;
    #   otherwise it belongs to the transducer class.
    # - Likewise, nemo gets its own category only with "ctc" or "canary" in the
    #   name; otherwise it is a transducer.
    # - conformer models are all transducers, but must be checked after nemo.
    # - "wenet" must be checked late: "wenetspeech" is also a dataset name, so
    #   it can show up in model names of any type.
    model_type = "unknown"
    model_name_lower = model_name.lower()
    if "tdnn" in model_name_lower:
        model_type = "tdnn_ctc"  # Of little practical use: the only such model recognizes just the Hebrew words for yes/no.
    elif "moonshine" in model_name_lower:
        model_type = "moonshine"
    elif "fire-red" in model_name_lower:
        model_type = "fire_red"
    elif "dolphin" in model_name_lower:
        model_type = "dolphin_ctc"
    elif "paraformer" in model_name_lower:
        model_type = "paraformer"
    elif "telespeech" in model_name_lower:
        model_type = "telespeech_ctc"
    elif "whisper" in model_name_lower:
        model_type = "whisper"
    elif "sense-voice" in model_name_lower:
        model_type = "sensevoice"
    elif "zipformer" in model_name_lower:
        if "ctc" in model_name_lower:
            model_type = "zipformer_ctc"
        else:
            model_type = "transducer"
    elif "nemo" in model_name_lower:
        if "ctc" in model_name_lower:
            model_type = "nemo_ctc"
        elif "canary" in model_name_lower:
            model_type = "nemo_canary"
        else:
            model_type = "transducer"
    elif "conformer" in model_name_lower or "lstm" in model_name_lower:
        model_type = "transducer"
    elif "wenet" in model_name_lower:
        model_type = "wenet_ctc"
    else:
        model_type = "unknown"
    return model_type
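
# A few illustrative mappings (the names are hypothetical, modeled on sherpa-onnx
# release naming conventions):
#   get_asr_model_type("sherpa-onnx-sense-voice-zh-en")  -> "sensevoice"
#   get_asr_model_type("sherpa-onnx-zipformer-ctc-zh")   -> "zipformer_ctc"
#   get_asr_model_type("sherpa-onnx-zipformer-en")       -> "transducer"
#   get_asr_model_type("sherpa-onnx-nemo-canary-en")     -> "nemo_canary"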

@app.on_event("startup")
def load_model():
    global status, recognizer, device, model_type
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    _model_type = config.get("model_type", None)
    _model_name = config.get("model_name", None)
    warmup = config.get("warmup", False)
    isOffline = config.get("offline_model", True)
    num_threads = config.get("num_threads", 2)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # sherpa-onnx model types are messy. Users who know theirs can pass model_type
    # directly, or pass the full model name instead, because the path mounted into
    # the image does not necessarily contain the model name.
    # model_name is also needed later (the transducer branch checks it for "nemo"),
    # so derive it up front even when model_type is given explicitly.
    model_name = _model_name if _model_name else os.path.basename(model_dir)
    if _model_type:
        model_type = _model_type
    elif _model_name:
        model_type = get_asr_model_type(_model_name)
    else:
        print("model_name and model_type both not provided, start guessing using model_dir", flush=True)
        model_type = get_asr_model_type(model_name)

    print(">> Startup config:", flush=True)
    print("  model_dir =", model_dir, flush=True)
    print("  model_type =", model_type, flush=True)
    print("  use_gpu =", use_gpu, flush=True)
    print("  warmup =", warmup, flush=True)
    print("  isOffline =", isOffline, flush=True)
    print("  num_threads =", num_threads, flush=True)

    try:
        recognizer = None
        provider = "cuda" if use_gpu else "cpu"
        file_list = os.listdir(model_dir)
        # The directory may contain several sets of model files (e.g. quantized and
        # non-quantized versions); always pick the largest file of each kind.
        if model_type == "whisper":
            encoder_list, decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
                encoder=model_dir + "/" + encoder,
                decoder=model_dir + "/" + decoder,
                tokens=model_dir + "/" + tokens,
                language="zh",
                debug=False,
                provider=provider,
                num_threads=num_threads
            )

        elif model_type == "sensevoice":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                use_itn=True,
                language="zh",
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "paraformer":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
                    paraformer=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
                    paraformer=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "zipformer_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_zipformer_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "telespeech_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "fire_red":
            encoder_list, decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
                encoder=model_dir + "/" + encoder,
                decoder=model_dir + "/" + decoder,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "wenet_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "dolphin_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_dolphin_ctc(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "transducer":
            encoder_list, decoder_list, joiner_list = [], [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "joiner" in file and file.endswith(".onnx"):
                    joiner_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            joiner = sorted(joiner_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            # Special case: zipformer and conformer transducers are exported by
            # icefall and can use the default type; NeMo transducers need their own flag.
            transducer_type = "nemo_transducer" if "nemo" in model_name.lower() else "transducer"
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
                    encoder=model_dir + "/" + encoder,
                    decoder=model_dir + "/" + decoder,
                    joiner=model_dir + "/" + joiner,
                    tokens=model_dir + "/" + tokens,
                    model_type=transducer_type,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
                    encoder=model_dir + "/" + encoder,
                    decoder=model_dir + "/" + decoder,
                    joiner=model_dir + "/" + joiner,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "nemo_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "nemo_canary":
            encoder_list, decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
                encoder=model_dir + "/" + encoder,
                decoder=model_dir + "/" + decoder,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "moonshine":
            preprocessor_list, encoder_list, cached_decoder_list, uncached_decoder_list = [], [], [], []
            tokens = ""
            for file in file_list:
                if "preprocess" in file and file.endswith(".onnx"):
                    preprocessor_list.append(file)
                elif "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "uncached_decode" in file and file.endswith(".onnx"):
                    uncached_decoder_list.append(file)
                elif "cached_decode" in file and file.endswith(".onnx"):
                    cached_decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            preprocessor = sorted(preprocessor_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            cached_decoder = sorted(cached_decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            uncached_decoder = sorted(uncached_decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
                preprocessor=model_dir + "/" + preprocessor,
                encoder=model_dir + "/" + encoder,
                cached_decoder=model_dir + "/" + cached_decoder,
                uncached_decoder=model_dir + "/" + uncached_decoder,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "tdnn_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        else:
            raise RuntimeError(f"Cannot recognize model_type: {model_type}")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize model: {e}")

    if warmup:
        print("Start warmup...", flush=True)
        stream = recognizer.create_stream()
        audio, sample_rate = sf.read("warmup.wav", dtype="float32", always_2d=True)
        audio = audio[:, 0]  # accept_waveform expects a 1-D mono waveform
        stream.accept_waveform(sample_rate, audio)
        recognizer.decode_stream(stream)
        print("warmup complete.", flush=True)

    status = "Success"

def test_sherpa(wavefile):
    isOffline = app.state.config.get("offline_model", True)
    audio, sample_rate = sf.read(wavefile, dtype="float32", always_2d=True)
    audio = audio[:, 0]
    generated_text = ""

    start_t = datetime.datetime.now()
    if isOffline:
        # Non-streaming inference with an OfflineRecognizer.
        if model_type in ["sensevoice"]:
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, audio)
            recognizer.decode_stream(stream)
            generated_text = stream.result.text
        else:
            # Most offline ASR models handle long audio poorly: they are trained on
            # short clips, and the exported ONNX graphs may cap some intermediate
            # dimensions. Even the original CPU inference can crash on long input,
            # so decode in short segments instead.
            start_index = 0
            interval = int(sample_rate * 29)
            while start_index < len(audio):
                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, audio[start_index:start_index + interval])
                recognizer.decode_stream(stream)
                generated_text += stream.result.text
                start_index += interval
    else:
        # Streaming inference with an OnlineRecognizer: always feed 2 s of audio at a time.
        stream = recognizer.create_stream()
        start_index = 0
        chunk_size = int(sample_rate * 2)

        while start_index < len(audio):
            chunk = audio[start_index:start_index + chunk_size]
            stream.accept_waveform(sample_rate, chunk)

            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
                # mid_text = recognizer.get_result(stream)
                # print("partial result: " + mid_text, flush=True)
            start_index += chunk_size

        # Signal end of input so the recognizer flushes its remaining frames.
        stream.input_finished()
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        generated_text = recognizer.get_result(stream)

    end_t = datetime.datetime.now()
    elapsed_seconds = (end_t - start_t).total_seconds()
    duration = audio.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration
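    # An RTF (real-time factor) below 1.0 means decoding ran faster than real time.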

    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{elapsed_seconds:.3f} s", flush=True)
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)

    return generated_text

@app.get("/health")
def health():
    if status == "Running":
        return {"status": "loading model"}
    return {"status": "ok" if status == "Success" else "failed"}

@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    try:
        file_path = f"./input/{uuid.uuid4()}.wav"
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        background_tasks.add_task(os.remove, file_path)
        generated_text = test_sherpa(file_path)

        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")

# if __name__ == "__main__":
#     uvicorn.run("fastapi_sherpa:app", host="0.0.0.0", port=1111, workers=1)
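
# Example request against a running instance (assuming port 1111 as in the
# commented-out launcher above; adjust to your deployment):
#   curl -X POST -F "audio=@sample.wav" -F "lang=zh" http://localhost:1111/transduce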