initial commit
sherpa-onnx/Dockerfile.sherpa-onnx-bi150 (Normal file, 25 lines)
@@ -0,0 +1,25 @@
FROM corex:4.3.8

WORKDIR /root

RUN set -eux; \
    # 1) Replace the aliyun mirrors with the official Ubuntu archive (avoids 403 errors)
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list; \
    sed -i -E 's|http://mirrors\.aliyun\.com/ubuntu|http://archive.ubuntu.com/ubuntu|g' /etc/apt/sources.list.d/*.list 2>/dev/null || true; \
    \
    # 2) Update the package index and install dependencies
    apt-get update; \
    apt-get install -y --no-install-recommends vim net-tools ca-certificates libasound2-dev patchelf; \
    rm -rf /var/lib/apt/lists/*

ADD . /root/

COPY requirements.txt /root
RUN pip install -r requirements.txt -i https://nexus.4pd.io/repository/pypi-all/simple --extra-index-url https://mirror.sjtu.edu.cn/pypi/web/simple
COPY sherpa_onnx-1.12.5+corex4.3.8-cp310-cp310-linux_x86_64.whl /root
RUN pip install ./sherpa_onnx-1.12.5+corex4.3.8-cp310-cp310-linux_x86_64.whl

ENV LD_LIBRARY_PATH=/usr/local/corex-4.3.8/lib64/python3/dist-packages/tvm/:$LD_LIBRARY_PATH

ENTRYPOINT ["python3"]
CMD ["./main_sherpa.py"]
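# Note: ENTRYPOINT is python3 and CMD only supplies the default script, so any
# arguments appended after the image name in `docker run` replace CMD; this is
# how the README example passes `main_sherpa.py --model_dir ...` at run time.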
sherpa-onnx/README.md (Normal file, 28 lines)
@@ -0,0 +1,28 @@
# 天数智芯 天垓150 ASR (Sherpa-ONNX stack)

## Building the image

```shell
docker build -f ./Dockerfile.sherpa-onnx-bi150 -t <your_image> .
```

The base image corex:4.3.8 can be obtained by contacting 天数智芯 (Iluvatar CoreX) 智铠100 vendor technical support.

## Usage

### Starting the ASR service with FastAPI

For example:

```shell
docker run -dit -v /usr/src:/usr/src -v /lib/modules:/lib/modules --device=/dev/iluvatar0:/dev/iluvatar0 \
    -v /mnt/contest_ceph/leaderboard/modelHubXC/mariolux/sherpa-onnx-dolphin-small-ctc-multi-lang-2025-04-02:/model \
    --network=host <your_image> \
    main_sherpa.py --model_dir /model --model_type dolphin_ctc --offline_model --use_gpu --port 1111
```

The full parameter definitions are in the code (main_sherpa.py); a summary follows.
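For reference, `main_sherpa.py` accepts: `--model_dir` (model directory, default `/model`), `--model_type` (e.g. `sensevoice`), `--model_name` (full model name, used to infer the type), `--use_gpu`, `--warmup`, `--num_threads` (default 2), `--offline_model` (set for non-streaming models), and `--port` (default 8000).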

### Testing the ASR service

The `sample_data` directory at the project root contains a Chinese test audio file along with its reference transcript.

```shell
curl -X POST http://localhost:1111/transduce \
    -F "audio=@../sample_data/lei-jun-test.wav" \
    -F "lang=zh"
```
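
Given the handler in `fastapi_sherpa.py`, a successful request returns JSON of the form (text abridged):

```json
{"generated_text": "..."}
```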
sherpa-onnx/fastapi_sherpa.py (Normal file, 520 lines)
@@ -0,0 +1,520 @@
import os
import sys
import time
import uuid
import datetime
import tempfile
import soundfile as sf
import sherpa_onnx
import traceback

from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form

os.makedirs("./input", exist_ok=True)
status = "Running"
recognizer = None
device = ""
model_type = ""
app = FastAPI()

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
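# CUSTOM_DEVICE selects the accelerator family at runtime (load_model() below
# checks for values starting with "mlu" or "ascend"); when unset, CUDA or CPU
# is assumed.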

# Determine the model type from the model name. The taxonomy is messy; the
# custom types handled here (for OfflineRecognizer) are:
# moonshine
# fire_red
# dolphin_ctc
# paraformer
# telespeech_ctc
# whisper
# sensevoice
# zipformer_ctc
# transducer
# nemo_ctc
# nemo_canary
# wenet_ctc
# For OnlineRecognizer only five types apply: zipformer_ctc, transducer,
# paraformer, nemo_ctc, and wenet_ctc.
def get_asr_model_type(model_name):
    # Infer the model type (and hence the language task) from the name.
    # nemo_ctc, nemo_canary, and moonshine currently have no Chinese models in
    # sherpa-onnx and run English ASR; the remaining types run Chinese ASR.
    # None of the nemo models (nemo_ctc, nemo_canary, or the nemo variants of
    # transducer) have Chinese models, and not every family has an English
    # model either.

    # Special rules:
    # - zipformer counts as zipformer_ctc only when "ctc" is in the name;
    #   otherwise it belongs to the transducer class
    # - nemo likewise needs "ctc" or "canary" in the name to get its own class;
    #   otherwise it is a transducer
    # - conformer models are all transducers, but must be checked after nemo
    # - "wenet" also appears in the dataset name wenetspeech, which models of
    #   any type may carry, so the wenet check must come near the end
    model_type = "unknown"
    model_name_lower = model_name.lower()
    if "tdnn" in model_name_lower:
        # The tdnn category is of little practical use: the only available
        # model can recognize just the Hebrew words yes/no. Named "tdnn_ctc"
        # to match the loader branch below.
        model_type = "tdnn_ctc"
    elif "moonshine" in model_name_lower:
        model_type = "moonshine"
    elif "fire-red" in model_name_lower:
        model_type = "fire_red"
    elif "dolphin" in model_name_lower:
        model_type = "dolphin_ctc"
    elif "paraformer" in model_name_lower:
        model_type = "paraformer"
    elif "telespeech" in model_name_lower:
        model_type = "telespeech_ctc"
    elif "whisper" in model_name_lower:
        model_type = "whisper"
    elif "sense-voice" in model_name_lower:
        model_type = "sensevoice"
    elif "zipformer" in model_name_lower:
        if "ctc" in model_name_lower:
            model_type = "zipformer_ctc"
        else:
            model_type = "transducer"
    elif "nemo" in model_name_lower:
        if "ctc" in model_name_lower:
            model_type = "nemo_ctc"
        elif "canary" in model_name_lower:
            model_type = "nemo_canary"
        else:
            model_type = "transducer"
    elif "conformer" in model_name_lower or "lstm" in model_name_lower:
        model_type = "transducer"
    elif "wenet" in model_name_lower:
        model_type = "wenet_ctc"
    else:
        model_type = "unknown"
    return model_type
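# Illustrative mappings (a sketch; the names below follow typical sherpa-onnx
# release naming and are used here only as examples):
#   get_asr_model_type("sherpa-onnx-whisper-tiny")              -> "whisper"
#   get_asr_model_type("sherpa-onnx-zipformer-ctc-zh-2023")     -> "zipformer_ctc"
#   get_asr_model_type("sherpa-onnx-zipformer-en-2023-04-01")   -> "transducer"
#   get_asr_model_type("sherpa-onnx-nemo-canary-180m-en-es")    -> "nemo_canary"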

@app.on_event("startup")
def load_model():
    global status, recognizer, device, model_type
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    _model_type = config.get("model_type", None)
    _model_name = config.get("model_name", None)
    warmup = config.get("warmup", False)
    isOffline = config.get("offline_model", True)
    num_threads = config.get("num_threads", 2)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:0"
        else:
            device = "cuda:0"

    # sherpa-onnx model types are numerous. Users who know the type can pass
    # model_type directly, or alternatively the full model name, since the
    # path mounted into the image does not necessarily contain the model name.
    if _model_type:
        model_type = _model_type
    elif _model_name:
        model_type = get_asr_model_type(_model_name)
    else:
        print("model_name and model_type both not provided, start guessing using model_dir", flush=True)
        model_name = os.path.basename(model_dir)
        model_type = get_asr_model_type(model_name)

    print(">> Startup config:")
    print("   model_dir =", model_dir, flush=True)
    print("   model_type =", model_type, flush=True)
    print("   use_gpu =", use_gpu, flush=True)
    print("   warmup =", warmup, flush=True)
    print("   isOffline =", isOffline, flush=True)
    print("   num_threads =", num_threads, flush=True)

    try:
        recognizer = None
        provider = "cuda" if use_gpu else "cpu"
        file_list = os.listdir(model_dir)
        # The directory may contain several sets of model files (e.g. quantized
        # and non-quantized); pick the largest file for each role.
        if model_type == "whisper":
            encoder_list, decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
                encoder=model_dir + "/" + encoder,
                decoder=model_dir + "/" + decoder,
                tokens=model_dir + "/" + tokens,
                language="zh",
                debug=False,
                provider=provider,
                num_threads=num_threads
            )

        elif model_type == "sensevoice":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                use_itn=True,
                language="zh",
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "paraformer":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
                    paraformer=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
                    paraformer=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "zipformer_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_zipformer_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "telespeech_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "fire_red":
            encoder_list, decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
                encoder=model_dir + "/" + encoder,
                decoder=model_dir + "/" + decoder,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "wenet_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "dolphin_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_dolphin_ctc(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "transducer":
            encoder_list, decoder_list, joiner_list = [], [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "joiner" in file and file.endswith(".onnx"):
                    joiner_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            joiner = sorted(joiner_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            # Special case: zipformer and conformer transducers are exported by
            # icefall and use the default type; nemo transducers must be flagged
            # explicitly. (Derive the reference name here, because model_name is
            # only bound when neither model_type nor model_name was provided.)
            ref_name = (_model_name or os.path.basename(model_dir)).lower()
            transducer_type = "nemo_transducer" if "nemo" in ref_name else "transducer"
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
                    encoder=model_dir + "/" + encoder,
                    decoder=model_dir + "/" + decoder,
                    joiner=model_dir + "/" + joiner,
                    tokens=model_dir + "/" + tokens,
                    model_type=transducer_type,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
                    encoder=model_dir + "/" + encoder,
                    decoder=model_dir + "/" + decoder,
                    joiner=model_dir + "/" + joiner,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "nemo_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            if isOffline:
                recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
            else:
                recognizer = sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
                    model=model_dir + "/" + model,
                    tokens=model_dir + "/" + tokens,
                    debug=False,
                    provider=provider,
                    num_threads=num_threads
                )
        elif model_type == "nemo_canary":
            encoder_list, decoder_list = [], []
            tokens = ""
            for file in file_list:
                if "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "decode" in file and file.endswith(".onnx"):
                    decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            decoder = sorted(decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
                encoder=model_dir + "/" + encoder,
                decoder=model_dir + "/" + decoder,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "moonshine":
            preprocessor_list, encoder_list, cached_decoder_list, uncached_decoder_list = [], [], [], []
            tokens = ""
            for file in file_list:
                if "preprocess" in file and file.endswith(".onnx"):
                    preprocessor_list.append(file)
                elif "encode" in file and file.endswith(".onnx"):
                    encoder_list.append(file)
                elif "uncached_decode" in file and file.endswith(".onnx"):
                    uncached_decoder_list.append(file)
                elif "cached_decode" in file and file.endswith(".onnx"):
                    cached_decoder_list.append(file)
                elif "token" in file and file.endswith(".txt"):
                    tokens = file
            preprocessor = sorted(preprocessor_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            encoder = sorted(encoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            cached_decoder = sorted(cached_decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            uncached_decoder = sorted(uncached_decoder_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
                preprocessor=model_dir + "/" + preprocessor,
                encoder=model_dir + "/" + encoder,
                cached_decoder=model_dir + "/" + cached_decoder,
                uncached_decoder=model_dir + "/" + uncached_decoder,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        elif model_type == "tdnn_ctc":
            model_list = []
            tokens = ""
            for file in file_list:
                if file.endswith(".onnx"):
                    model_list.append(file)
                elif file.endswith(".txt") and "token" in file:
                    tokens = file
            model = sorted(model_list, key=lambda x: os.path.getsize(model_dir + "/" + x), reverse=True)[0]
            recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
                model=model_dir + "/" + model,
                tokens=model_dir + "/" + tokens,
                debug=False,
                provider=provider,
                num_threads=num_threads
            )
        else:
            raise RuntimeError("Cannot recognize model_type")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize model: {e}")

    if warmup:
        print("Start warmup...", flush=True)
        stream = recognizer.create_stream()
        audio, sample_rate = sf.read("warmup.wav", dtype="float32", always_2d=True)
        audio = audio[:, 0]  # accept_waveform expects mono 1-D samples
        stream.accept_waveform(sample_rate, audio)
        recognizer.decode_stream(stream)
        print("warmup complete.", flush=True)

    status = "Success"
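
# Possible refactor (a sketch only; not wired into load_model above): every
# branch repeats the "collect candidate files, pick the largest" pattern,
# which could be factored out as e.g.:
#
#   def pick_largest(model_dir, file_list, must_contain, suffix=".onnx"):
#       cand = [f for f in file_list if must_contain in f and f.endswith(suffix)]
#       return max(cand, key=lambda f: os.path.getsize(os.path.join(model_dir, f)))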

def test_sherpa(wavefile):
    isOffline = app.state.config.get("offline_model", True)
    audio, sample_rate = sf.read(wavefile, dtype="float32", always_2d=True)
    audio = audio[:, 0]
    generated_text = ""

    start_t = datetime.datetime.now()
    if isOffline:
        # Non-streaming inference with OfflineRecognizer
        if model_type in ["sensevoice"]:
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, audio)
            recognizer.decode_stream(stream)
            generated_text = stream.result.text
        else:
            # Most offline ASR models handle long audio poorly: the training
            # clips are short, and the exported ONNX graph may cap some
            # intermediate dimensions (e.g. Whisper-style models use a 30 s
            # window). Even the original CPU inference can crash on long input,
            # so decode in short segments.
            start_index = 0
            interval = int(sample_rate * 29)
            while start_index < len(audio):
                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, audio[start_index:start_index + interval])
                recognizer.decode_stream(stream)
                generated_text += stream.result.text
                start_index += interval
    else:
        # Streaming inference with OnlineRecognizer: feed 2 s of audio per step
        stream = recognizer.create_stream()
        start_index = 0
        chunk_size = int(sample_rate * 2)

        while start_index < len(audio):
            chunk = audio[start_index:start_index + chunk_size]
            stream.accept_waveform(sample_rate, chunk)

            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
                # mid_text = recognizer.get_result(stream)
                # print("partial result: " + mid_text, flush=True)
            start_index += chunk_size

        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        generated_text = recognizer.get_result(stream)

    end_t = datetime.datetime.now()
    elapsed_seconds = (end_t - start_t).total_seconds()
    duration = audio.shape[-1] / sample_rate
    rtf = elapsed_seconds / duration  # real-time factor: processing time / audio duration

    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{elapsed_seconds:.3f} s", flush=True)
    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)

    return generated_text


@app.get("/health")
def health():
    if status == "Running":
        return {
            "status": "loading model"
        }
    ret = {
        "status": "ok" if status == "Success" else "failed",
    }
    return ret
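
# Illustrative readiness check (the port depends on --port):
#   curl http://localhost:1111/health
# returns {"status": "loading model"} while the model is loading and
# {"status": "ok"} once startup has completed.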

@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    try:
        file_path = f"./input/{uuid.uuid4()}.wav"
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        background_tasks.add_task(os.remove, file_path)
        generated_text = test_sherpa(file_path)

        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")


# if __name__ == "__main__":
#     uvicorn.run("fastapi_sherpa:app", host="0.0.0.0", port=1111, workers=1)
sherpa-onnx/main_sherpa.py (Normal file, 33 lines)
@@ -0,0 +1,33 @@
import argparse
import uvicorn
from fastapi_sherpa import app


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/model", help="model directory")
    parser.add_argument("--model_type", type=str, default=None, help="model type, e.g. sensevoice")
    parser.add_argument("--use_gpu", action="store_true", help="run inference on the GPU when this flag is set")
    parser.add_argument("--warmup", action="store_true", help="whether to warm up the model right after loading")
    parser.add_argument("--model_name", type=str, default=None, help="the model's full name (optional), used to determine the model type")
    parser.add_argument("--num_threads", type=int, default=2, help="number of threads for model inference")
    parser.add_argument("--offline_model", action="store_true", help="indicates a non-streaming model when this flag is set")
    parser.add_argument("--port", type=int, default=8000, help="service port")

    args = parser.parse_args()

    # Pass the parsed arguments to the app via app.state
    app.state.config = {
        "model_dir": args.model_dir,
        "model_type": args.model_type,
        "model_name": args.model_name,
        "num_threads": args.num_threads,
        "offline_model": args.offline_model,
        "use_gpu": args.use_gpu,
        "warmup": args.warmup,
    }

    uvicorn.run("fastapi_sherpa:app",
                host="0.0.0.0",
                port=args.port,
                workers=1
                )
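
# Example invocation (mirrors the README; model path and port are illustrative):
#   python3 main_sherpa.py --model_dir /model --model_type dolphin_ctc \
#       --offline_model --use_gpu --port 1111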
sherpa-onnx/requirements.txt (Normal file, 14 lines)
@@ -0,0 +1,14 @@
requests
wheel
websocket-client
pydantic>=2.0.0
numpy<2.0
PyYAML
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
fastapi
uvicorn
python-multipart
sherpa-onnx/warmup.wav (Normal file, binary file not shown)