# enginex-bi_series-asr/test_funasr.py
import os
import time
import torchaudio
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from utils.calculate import cal_per_cer
import json

def split_audio(waveform, sample_rate, segment_seconds=20):
    segment_samples = segment_seconds * sample_rate
    segments = []
    for i in range(0, waveform.shape[1], segment_samples):
        segment = waveform[:, i:i + segment_samples]
        if segment.shape[1] > 0:
            segments.append(segment)
    return segments
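
# A quick sanity check of the arithmetic above (illustrative numbers, not part
# of the test): a 65 s mono clip at 16 kHz is a (1, 1_040_000) tensor, so 20 s
# windows of 320_000 samples yield four segments, the last only 5 s long:
# >>> wav = torch.zeros(1, 65 * 16000)
# >>> [s.shape[1] for s in split_audio(wav, 16000)]
# [320000, 320000, 320000, 80000]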

def determine_model_type(model_name):
    if "sensevoice" in model_name.lower():
        return "sense_voice"
    elif "whisper" in model_name.lower():
        return "whisper"
    elif "paraformer" in model_name.lower():
        return "paraformer"
    elif "conformer" in model_name.lower():
        return "conformer"
    elif "uniasr" in model_name.lower():
        return "uni_asr"
    else:
        return "unknown"

def test_funasr(model_dir, audio_file, answer_file, use_gpu):
    model_name = os.path.basename(model_dir)
    model_type = determine_model_type(model_name)
    if (use_gpu and torch.cuda.is_available()
            and torch.cuda.get_device_name() == "Iluvatar BI-V100"
            and model_type == "whisper"):
        # Whisper on the Iluvatar BI-V100 (TianGai 100) must bypass unsupported
        # fused-attention operators, so force the math SDP backend.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    # No VAD, punctuation, or speaker models: test the raw ASR capability only.
    model = AutoModel(
        model=model_dir,
        # vad_model="fsmn-vad",
        # vad_kwargs={"max_single_segment_time": 30000},
        vad_model=None,
        device="cuda:0" if use_gpu else "cpu",
        disable_update=True
    )
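
    # Hypothetical sketch, not exercised here: for full-pipeline quality the
    # VAD and punctuation stages could be re-enabled with something like
    #   AutoModel(model=model_dir, vad_model="fsmn-vad",
    #             vad_kwargs={"max_single_segment_time": 30000},
    #             punc_model="ct-punc", ...)
    # ("fsmn-vad" comes from the commented lines above; "ct-punc" is an assumed
    # FunASR model zoo name.)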
    waveform, sample_rate = torchaudio.load(audio_file)
    # print(waveform.shape)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)
    generated_text = ""
    processing_time = 0
if model_type == "uni_asr":
# uni_asr比较特殊设计就是处理长音频的自带VAD切分的话前20s如果几乎没有人讲话全是音乐直接会报错
# 因为可能会被切掉所有音频导致实际编解码输入为0
ts1 = time.time()
res = model.generate(
input=audio_file
)
generated_text = res[0]["text"]
ts2 = time.time()
processing_time = ts2 - ts1
    else:
        # Feed the split segments through the model one by one.
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            if model_type == "sense_voice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                    # merge_length_s=15,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": "zh",
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # Paraformer emits output character by character; the extra
                # spaces in between would hurt the 1-CER score, so strip them.
                text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            # elif model_type == "uni_asr":
            #     if i == 0:
            #         os.remove(segment_path)
            #         continue
            #     res = model.generate(
            #         input=segment_path
            #     )
            #     text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)
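    # RTF (real-time factor) = processing time / audio duration; values below
    # 1.0 mean faster-than-real-time transcription, e.g. 30 s of compute on a
    # 60 s clip gives RTF = 0.5.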
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    with open(answer_file, 'r', encoding='utf-8') as f:
        groundtruth_text = f.read()
    acc = cal_per_cer(generated_text, groundtruth_text, "zh")
    print(f"1-cer = {acc}", flush=True)
    return processing_time, acc, generated_text
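
# Standalone usage sketch (hypothetical local paths; the __main__ block below
# drives the same call from environment variables):
#   elapsed, acc, text = test_funasr("/model/SenseVoiceSmall",
#                                    "lei-jun-test.wav", "lei-jun.txt",
#                                    use_gpu=True)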

if __name__ == "__main__":
    test_result = {
        "time_cuda": 0,
        "acc_cuda": 0,
        "text_cuda": "",
        "success": False
    }
    # Default so the result can still be written even if the try block fails
    # before result_file is assigned below.
    result_file = "result.json"
    try:
        model_dict = {
            "sense_voice": "SenseVoiceSmall",
            "paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
            "whisper": "Whisper-large-v3",
            "uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline"
        }
        LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true"
        K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true"
        workspace_path = "../" if LOCAL_TEST else "/tmp/workspace"
        model_dir = os.path.join("/model", model_dict["sense_voice"]) if LOCAL_TEST else os.environ["MODEL_DIR"]
        audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"])
        answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"])
        result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"])
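        # Example non-local invocation (hypothetical values):
        #   MODEL_DIR=/model/SenseVoiceSmall TEST_FILE=test.wav \
        #   ANSWER_FILE=answer.txt RESULT_FILE=result.json python test_funasr.py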
        # test_funasr(model_dir, audio_file, answer_file, False)
        processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True)
        test_result["time_cuda"] = processing_time
        test_result["acc_cuda"] = acc
        test_result["text_cuda"] = generated_text
        test_result["success"] = True
    except Exception as e:
        print(f"ASR test failed: {e}", flush=True)
    with open(result_file, "w", encoding="utf-8") as fp:
        json.dump(test_result, fp, ensure_ascii=False, indent=4)
    # When the image is brought up as the SUT, the block below keeps the pod
    # alive to satisfy the k8s Deployment; neither local testing nor
    # `docker run` needs this.
    if K8S_TEST:
        print("Start sleeping indefinitely", flush=True)
        time.sleep(100000)