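"""Benchmark FunASR-compatible ASR models (SenseVoice, Whisper, Paraformer,
Conformer, UniASR) on a single test audio file: transcribe it, report the
processing time, real-time factor (RTF), and 1-cer accuracy against a
ground-truth transcript, and write the results to a JSON file. Paths and modes
come from environment variables (LOCAL_TEST, K8S_TEST, MODEL_DIR, TEST_FILE,
ANSWER_FILE, RESULT_FILE).
"""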
import os
import time
import json

import torch
import torchaudio
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

from utils.calculate import cal_per_cer


def split_audio(waveform, sample_rate, segment_seconds=20):
    """Split a (channels, samples) waveform into chunks of at most segment_seconds."""
    segment_samples = segment_seconds * sample_rate
    segments = []
    for i in range(0, waveform.shape[1], segment_samples):
        segment = waveform[:, i:i + segment_samples]
        if segment.shape[1] > 0:
            segments.append(segment)
    return segments
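
# For example (illustrative numbers): at 16 kHz, a 50 s mono clip splits into
# three chunks of 20 s, 20 s, and 10 s:
#   split_audio(torch.zeros(1, 50 * 16000), 16000)  # -> list of 3 tensors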


def determine_model_type(model_name):
    """Infer the model family from the (case-insensitive) model directory name."""
    if "sensevoice" in model_name.lower():
        return "sense_voice"
    elif "whisper" in model_name.lower():
        return "whisper"
    elif "paraformer" in model_name.lower():
        return "paraformer"
    elif "conformer" in model_name.lower():
        return "conformer"
    elif "uniasr" in model_name.lower():
        return "uni_asr"
    else:
        return "unknown"
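
# With the directory names used in __main__ below, this maps e.g.
# "SenseVoiceSmall" -> "sense_voice" and "Whisper-large-v3" -> "whisper".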


def test_funasr(model_dir, audio_file, answer_file, use_gpu):
    model_name = os.path.basename(model_dir)
    model_type = determine_model_type(model_name)

    if use_gpu and torch.cuda.is_available() and model_type == "whisper" \
            and torch.cuda.get_device_name() == "Iluvatar BI-V100":
        # Whisper on the Iluvatar Tiangai 100 (BI-V100) must bypass unsupported
        # scaled-dot-product-attention kernels: disable the flash and
        # memory-efficient backends and fall back to the math backend.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

    # Skip the VAD, punctuation, and speaker models to benchmark raw ASR
    # capability only.
    model = AutoModel(
        model=model_dir,
        # vad_model="fsmn-vad",
        # vad_kwargs={"max_single_segment_time": 30000},
        vad_model=None,
        device="cuda:0" if use_gpu else "cpu",
        disable_update=True
    )

    waveform, sample_rate = torchaudio.load(audio_file)
    # print(waveform.shape)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)
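
    # Note: the checkpoints referenced in __main__ are 16 kHz models; audio at
    # another rate would need resampling first, e.g. (hypothetical):
    #   waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)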

    generated_text = ""
    processing_time = 0

    if model_type == "uni_asr":
        # UniASR is special: it is designed for long audio and ships with its
        # own VAD. If we fed it 20 s chunks and the first chunk were almost all
        # music with no speech, it would crash outright, because the VAD can
        # strip away all the audio and leave the encoder/decoder with a
        # zero-length input. So pass it the whole file instead.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the split segments one at a time.
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            if model_type == "sense_voice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # or "zh", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                    # merge_length_s=15,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": "zh",
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # Paraformer emits output one character at a time; the
                # interleaved spaces would drag down the 1-cer score, so strip
                # them.
                text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            # elif model_type == "uni_asr":
            #     if i == 0:
            #         os.remove(segment_path)
            #         continue
            #     res = model.generate(
            #         input=segment_path
            #     )
            #     text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)

    # Real-time factor: processing time divided by audio duration (< 1.0 means
    # faster than real time).
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    with open(answer_file, 'r', encoding='utf-8') as f:
        groundtruth_text = f.read()
    acc = cal_per_cer(generated_text, groundtruth_text, "zh")
    print(f"1-cer = {acc}", flush=True)

    return processing_time, acc, generated_text
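
# Standalone usage sketch (the literal paths mirror the LOCAL_TEST defaults
# below; adjust to taste):
#   test_funasr("/model/SenseVoiceSmall", "lei-jun-test.wav", "lei-jun.txt", True)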


if __name__ == "__main__":
    test_result = {
        "time_cuda": 0,
        "acc_cuda": 0,
        "text_cuda": "",
        "success": False
    }

    LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true"
    K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true"
    workspace_path = "../" if LOCAL_TEST else "/tmp/workspace"
    # result_file is resolved outside the try block: if it were assigned inside
    # and an earlier statement raised, the json.dump below would hit a
    # NameError instead of recording the failure.
    result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"])

    try:
        model_dict = {
            "sense_voice": "SenseVoiceSmall",
            "paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
            "whisper": "Whisper-large-v3",
            "uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline"
        }
        model_dir = os.path.join("/model", model_dict["sense_voice"]) if LOCAL_TEST else os.environ["MODEL_DIR"]
        audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"])
        answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"])
        # test_funasr(model_dir, audio_file, answer_file, False)
        processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True)
        test_result["time_cuda"] = processing_time
        test_result["acc_cuda"] = acc
        test_result["text_cuda"] = generated_text
        test_result["success"] = True
    except Exception as e:
        print(f"ASR test failed: {e}", flush=True)
    with open(result_file, "w", encoding="utf-8") as fp:
        json.dump(test_result, fp, ensure_ascii=False, indent=4)
    # When this image runs as the SUT, the pod must sleep forever so the k8s
    # Deployment does not restart it; local testing and plain `docker run` do
    # not need this.
    if K8S_TEST:
        print("Start to sleep indefinitely", flush=True)
        time.sleep(100000)
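
# For reference, result.json comes out shaped like (values illustrative):
#   {
#       "time_cuda": 123.456,
#       "acc_cuda": 0.95,
#       "text_cuda": "...",
#       "success": true
#   }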