Fix paraformer Englishword split

This commit is contained in:
2026-02-09 13:48:45 +08:00
parent 15b838d17d
commit 718a5bd24d
2 changed files with 33 additions and 16 deletions

View File

@@ -74,7 +74,7 @@ def split_audio(waveform, sample_rate, segment_seconds=20):
@app.on_event("startup")
def load_model():
global status, model, device
config = app.state.config
use_gpu = config.get("use_gpu", True)
model_dir = config.get("model_dir", "/model")
@@ -85,7 +85,7 @@ def load_model():
print(" model_type =", model_type, flush=True)
print(" use_gpu =", use_gpu, flush=True)
print(" warmup =", warmup, flush=True)
device = "cpu"
if use_gpu:
if CUSTOM_DEVICE.startswith("mlu"):
@@ -94,7 +94,7 @@ def load_model():
device = "npu:1"
else:
device = "cuda:0"
# 针对加速卡的特殊处理部分
if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
# 天垓100情况下的Whisper需要绕过不支持算子
@@ -102,14 +102,14 @@ def load_model():
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
print(f"device: {device}", flush=True)
dense_convert = False
if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
dense_convert = True
if device.startswith("npu") and model_type == "whisper":
# Ascend NPU 加载whisper的部分会有Sparse部分device不匹配
dense_convert = True
print(f"dense_convert: {dense_convert}", flush=True)
if dense_convert:
model = AutoModel(
@@ -132,23 +132,23 @@ def load_model():
device=device,
disable_update=True
)
if device.startswith("npu") or warmup:
# Ascend NPU由于底层设计的不同初始化卡的调度比其他卡更复杂要先进行warmup
print("Start warmup...", flush=True)
res = model.generate(input="warmup.wav")
print("warmup complete.", flush=True)
status = "Success"
def test_funasr(audio_file, lang):
# 推理部分
waveform, sample_rate = torchaudio.load(audio_file)
# print(waveform.shape)
duration = waveform.shape[1] / sample_rate
segments = split_audio(waveform, sample_rate, segment_seconds=20)
generated_text = ""
processing_time = 0
model_type = app.state.config.get("model_type", "sensevoice")
@@ -201,7 +201,8 @@ def test_funasr(audio_file, lang):
)
text = res[0]["text"]
# paraformer模型会一个字一个字输出中间夹太多空格会影响1-cer的结果
text = text.replace(" ", "")
if lang == "zh":
text = text.replace(" ", "")
elif model_type == "conformer":
res = model.generate(
input=segment_path,
@@ -222,7 +223,7 @@ def test_funasr(audio_file, lang):
generated_text += text
processing_time += (ts2 - ts1)
os.remove(segment_path)
rtf = processing_time / duration
print("Text:", generated_text, flush=True)
print(f"Audio duration:\t{duration:.3f} s", flush=True)
@@ -255,11 +256,11 @@ def transduce(
f.write(audio.file.read())
background_tasks.add_task(os.remove, file_path)
generated_text = test_funasr(file_path, lang)
return {"generated_text": generated_text}
except Exception:
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
# if __name__ == "__main__":
# uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)