Fix paraformer Englishword split
This commit is contained in:
@@ -74,7 +74,7 @@ def split_audio(waveform, sample_rate, segment_seconds=20):
|
||||
@app.on_event("startup")
|
||||
def load_model():
|
||||
global status, model, device
|
||||
|
||||
|
||||
config = app.state.config
|
||||
use_gpu = config.get("use_gpu", True)
|
||||
model_dir = config.get("model_dir", "/model")
|
||||
@@ -85,7 +85,7 @@ def load_model():
|
||||
print(" model_type =", model_type, flush=True)
|
||||
print(" use_gpu =", use_gpu, flush=True)
|
||||
print(" warmup =", warmup, flush=True)
|
||||
|
||||
|
||||
device = "cpu"
|
||||
if use_gpu:
|
||||
if CUSTOM_DEVICE.startswith("mlu"):
|
||||
@@ -94,7 +94,7 @@ def load_model():
|
||||
device = "npu:1"
|
||||
else:
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
# 针对加速卡的特殊处理部分
|
||||
if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
|
||||
# 天垓100情况下的Whisper需要绕过不支持算子
|
||||
@@ -102,14 +102,14 @@ def load_model():
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
print(f"device: {device}", flush=True)
|
||||
|
||||
|
||||
dense_convert = False
|
||||
if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
|
||||
dense_convert = True
|
||||
if device.startswith("npu") and model_type == "whisper":
|
||||
# Ascend NPU 加载whisper的部分会有Sparse部分device不匹配
|
||||
dense_convert = True
|
||||
|
||||
|
||||
print(f"dense_convert: {dense_convert}", flush=True)
|
||||
if dense_convert:
|
||||
model = AutoModel(
|
||||
@@ -132,23 +132,23 @@ def load_model():
|
||||
device=device,
|
||||
disable_update=True
|
||||
)
|
||||
|
||||
|
||||
if device.startswith("npu") or warmup:
|
||||
# Ascend NPU由于底层设计的不同,初始化卡的调度比其他卡更复杂,要先进行warmup
|
||||
print("Start warmup...", flush=True)
|
||||
res = model.generate(input="warmup.wav")
|
||||
print("warmup complete.", flush=True)
|
||||
|
||||
|
||||
status = "Success"
|
||||
|
||||
|
||||
|
||||
|
||||
def test_funasr(audio_file, lang):
|
||||
# 推理部分
|
||||
waveform, sample_rate = torchaudio.load(audio_file)
|
||||
# print(waveform.shape)
|
||||
duration = waveform.shape[1] / sample_rate
|
||||
segments = split_audio(waveform, sample_rate, segment_seconds=20)
|
||||
|
||||
|
||||
generated_text = ""
|
||||
processing_time = 0
|
||||
model_type = app.state.config.get("model_type", "sensevoice")
|
||||
@@ -201,7 +201,8 @@ def test_funasr(audio_file, lang):
|
||||
)
|
||||
text = res[0]["text"]
|
||||
# paraformer模型会一个字一个字输出,中间夹太多空格会影响1-cer的结果
|
||||
text = text.replace(" ", "")
|
||||
if lang == "zh":
|
||||
text = text.replace(" ", "")
|
||||
elif model_type == "conformer":
|
||||
res = model.generate(
|
||||
input=segment_path,
|
||||
@@ -222,7 +223,7 @@ def test_funasr(audio_file, lang):
|
||||
generated_text += text
|
||||
processing_time += (ts2 - ts1)
|
||||
os.remove(segment_path)
|
||||
|
||||
|
||||
rtf = processing_time / duration
|
||||
print("Text:", generated_text, flush=True)
|
||||
print(f"Audio duration:\t{duration:.3f} s", flush=True)
|
||||
@@ -255,11 +256,11 @@ def transduce(
|
||||
f.write(audio.file.read())
|
||||
background_tasks.add_task(os.remove, file_path)
|
||||
generated_text = test_funasr(file_path, lang)
|
||||
|
||||
|
||||
return {"generated_text": generated_text}
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
|
||||
|
||||
# if __name__ == "__main__":
|
||||
|
||||
|
||||
# uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)
|
||||
|
||||
Reference in New Issue
Block a user