Files
enginex-mr_series-asr/test_scripts/test_funasr.py
2026-02-04 17:34:39 +08:00

180 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import torchaudio
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from utils.calculate import cal_per_cer
import json
def split_audio(waveform, sample_rate, segment_seconds=20):
    """Chop a waveform into fixed-length chunks along the time axis.

    Args:
        waveform: tensor shaped (channels, samples) — assumes torchaudio's
            channel-first layout (TODO confirm against callers).
        sample_rate: samples per second of ``waveform``.
        segment_seconds: target length of each chunk; the final chunk may
            be shorter.

    Returns:
        List of (channels, <=segment_samples) tensor views, in order.
    """
    chunk_len = segment_seconds * sample_rate
    total_samples = waveform.shape[1]
    chunks = []
    start = 0
    while start < total_samples:
        piece = waveform[:, start:start + chunk_len]
        # Guard against zero-length tail slices (mirrors the original check).
        if piece.shape[1] > 0:
            chunks.append(piece)
        start += chunk_len
    return chunks
def determine_model_type(model_name):
    """Map a model directory/file name to a canonical model-type key.

    Matching is a case-insensitive substring search over a fixed list of
    known family names; returns "unknown" when nothing matches.

    Args:
        model_name: model name or directory basename, any case.

    Returns:
        One of "sense_voice", "whisper", "paraformer", "conformer",
        "uni_asr", or "unknown".
    """
    # Lowercase once instead of in every branch (original re-lowered per test).
    lowered = model_name.lower()
    # Checked in the original's order; the substrings are mutually exclusive.
    patterns = (
        ("sensevoice", "sense_voice"),
        ("whisper", "whisper"),
        ("paraformer", "paraformer"),
        ("conformer", "conformer"),
        ("uniasr", "uni_asr"),
    )
    for needle, model_type in patterns:
        if needle in lowered:
            return model_type
    return "unknown"
def test_funasr(model_dir, audio_file, answer_file, use_gpu):
    """Benchmark one FunASR model on a single audio file: transcribe it,
    time the inference, and score the transcript against a reference text.

    Args:
        model_dir: path to the model directory; its basename determines the
            model family (see determine_model_type).
        audio_file: path to the input audio file (loaded via torchaudio).
        answer_file: path to the UTF-8 ground-truth transcript.
        use_gpu: when True run on "cuda:0", otherwise on CPU.

    Returns:
        (processing_time, acc, generated_text) — total inference seconds
        (excluding model load and audio I/O), the 1-cer accuracy from
        cal_per_cer, and the concatenated transcript.

    Side effects: writes and removes temp_seg_<i>.wav files in the CWD for
    the segmented path; prints timing/RTF/accuracy to stdout.
    """
    model_name = os.path.basename(model_dir)
    model_type = determine_model_type(model_name)
    if torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # On Iluvatar BI-V100 (Tiangai-100) GPUs, Whisper must bypass
        # unsupported fused attention kernels: force the math SDP backend.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    # No VAD, punctuation, or speaker models: benchmark raw ASR capability only.
    model = AutoModel(
        model=model_dir,
        # vad_model="fsmn-vad",
        # vad_kwargs={"max_single_segment_time": 30000},
        vad_model=None,
        device="cuda:0" if use_gpu else "cpu",
        disable_update=True
    )
    waveform, sample_rate = torchaudio.load(audio_file)
    # print(waveform.shape)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)
    generated_text = ""
    processing_time = 0
    if model_type == "uni_asr":
        # uni_asr is built for long audio and does its own VAD splitting; if the
        # first 20 s is nearly speech-free (e.g. all music) it crashes, because
        # VAD can drop every frame and leave the encoder/decoder a zero-length
        # input — so feed it the whole file instead of 20 s segments.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the 20-second segments one at a time.
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            if model_type == "sense_voice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # "zn", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                    # merge_length_s=15,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": "zh",
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # paraformer emits output character by character with many
                # interior spaces, which would skew the 1-cer score — strip them.
                text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            # elif model_type == "uni_asr":
            #     if i == 0:
            #         os.remove(segment_path)
            #         continue
            #     res = model.generate(
            #         input=segment_path
            #     )
            #     text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)
    # RTF = processing seconds per second of audio (lower is faster).
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    with open(answer_file, 'r', encoding='utf-8') as f:
        groundtruth_text = f.read()
    acc = cal_per_cer(generated_text, groundtruth_text, "zh")
    print(f"1-cer = {acc}", flush=True)
    return processing_time, acc, generated_text
if __name__ == "__main__":
    # Result skeleton; "success" stays False unless the full run completes.
    test_result = {
        "time_cuda": 0,
        "acc_cuda": 0,
        "text_cuda": "",
        "success": False
    }
    # Pre-bind these before the try block: the except path and the code after
    # it use them, and an early failure (e.g. a missing env var, which is
    # precisely what the except exists to record) would otherwise raise
    # NameError and lose the failure record entirely.
    result_file = "result.json"
    K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true"
    try:
        # Known model directory names keyed by canonical model type.
        model_dict = {
            "sense_voice": "SenseVoiceSmall",
            "paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
            "whisper": "Whisper-large-v3",
            "uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline"
        }
        LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true"
        # Local runs use fixed fixture paths; CI/K8s runs resolve everything
        # from environment variables under the shared workspace.
        workspace_path = "../" if LOCAL_TEST else "/tmp/workspace"
        model_dir = os.path.join("/model", model_dict["sense_voice"]) if LOCAL_TEST else os.environ["MODEL_DIR"]
        audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"])
        answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"])
        result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"])
        # test_funasr(model_dir, audio_file, answer_file, False)
        processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True)
        test_result["time_cuda"] = processing_time
        test_result["acc_cuda"] = acc
        test_result["text_cuda"] = generated_text
        test_result["success"] = True
    except Exception as e:
        print(f"ASR测试出错: {e}", flush=True)
    # Always persist the result, even on failure.
    with open(result_file, "w", encoding="utf-8") as fp:
        json.dump(test_result, fp, ensure_ascii=False, indent=4)
    # When running as an SUT image under a k8s Deployment, keep the pod alive
    # after the test; local runs and plain `docker run` do not need this.
    if K8S_TEST:
        print("Start to sleep indefinitely", flush=True)
        time.sleep(100000)