import json
import os
import time
import traceback

import torch
import torchaudio

from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

from utils.calculate import cal_per_cer

def split_audio(waveform, sample_rate, segment_seconds=20):
|
||
segment_samples = segment_seconds * sample_rate
|
||
segments = []
|
||
for i in range(0, waveform.shape[1], segment_samples):
|
||
segment = waveform[:, i:i + segment_samples]
|
||
if segment.shape[1] > 0:
|
||
segments.append(segment)
|
||
return segments
def determine_model_type(model_name):
|
||
if "sensevoice" in model_name.lower():
|
||
return "sense_voice"
|
||
elif "whisper" in model_name.lower():
|
||
return "whisper"
|
||
elif "paraformer" in model_name.lower():
|
||
return "paraformer"
|
||
elif "conformer" in model_name.lower():
|
||
return "conformer"
|
||
elif "uniasr" in model_name.lower():
|
||
return "uni_asr"
|
||
else:
|
||
return "unknown"
def test_funasr(model_dir, audio_file, answer_file, use_gpu):
|
||
model_name = os.path.basename(model_dir)
|
||
model_type = determine_model_type(model_name)
|
||
|
||
if torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
|
||
# 天垓100情况下的Whisper需要绕过不支持算子
|
||
torch.backends.cuda.enable_flash_sdp(False)
|
||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||
torch.backends.cuda.enable_math_sdp(True)
|
||
|
||
# 不使用VAD, punct,spk模型,就测试原始ASR能力
|
||
model = AutoModel(
|
||
model=model_dir,
|
||
# vad_model="fsmn-vad",
|
||
# vad_kwargs={"max_single_segment_time": 30000},
|
||
vad_model=None,
|
||
device="cuda:0" if use_gpu else "cpu",
|
||
disable_update=True
|
||
)
|
||
|
||
waveform, sample_rate = torchaudio.load(audio_file)
|
||
# print(waveform.shape)
|
||
duration = waveform.shape[1] / sample_rate
|
||
segments = split_audio(waveform, sample_rate, segment_seconds=20)
|
||
|
||
generated_text = ""
|
||
processing_time = 0
|
||
|
||
if model_type == "uni_asr":
|
||
# uni_asr比较特殊,设计就是处理长音频的(自带VAD),切分的话前20s如果几乎没有人讲话全是音乐直接会报错
|
||
# 因为可能会被切掉所有音频导致实际编解码输入为0
|
||
ts1 = time.time()
|
||
res = model.generate(
|
||
input=audio_file
|
||
)
|
||
generated_text = res[0]["text"]
|
||
ts2 = time.time()
|
||
processing_time = ts2 - ts1
|
||
else:
|
||
# 按照切分的音频依次输入
|
||
for i, segment in enumerate(segments):
|
||
segment_path = f"temp_seg_{i}.wav"
|
||
torchaudio.save(segment_path, segment, sample_rate)
|
||
ts1 = time.time()
|
||
if model_type == "sense_voice":
|
||
res = model.generate(
|
||
input=segment_path,
|
||
cache={},
|
||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||
use_itn=True,
|
||
batch_size_s=60,
|
||
merge_vad=False,
|
||
# merge_length_s=15,
|
||
)
|
||
text = rich_transcription_postprocess(res[0]["text"])
|
||
elif model_type == "whisper":
|
||
DecodingOptions = {
|
||
"task": "transcribe",
|
||
"language": "zh",
|
||
"beam_size": None,
|
||
"fp16": False,
|
||
"without_timestamps": False,
|
||
"prompt": None,
|
||
}
|
||
res = model.generate(
|
||
DecodingOptions=DecodingOptions,
|
||
input=segment_path,
|
||
batch_size_s=0,
|
||
)
|
||
text = res[0]["text"]
|
||
elif model_type == "paraformer":
|
||
res = model.generate(
|
||
input=segment_path,
|
||
batch_size_s=300
|
||
)
|
||
text = res[0]["text"]
|
||
# paraformer模型会一个字一个字输出,中间夹太多空格会影响1-cer的结果
|
||
text = text.replace(" ", "")
|
||
elif model_type == "conformer":
|
||
res = model.generate(
|
||
input=segment_path,
|
||
batch_size_s=300
|
||
)
|
||
text = res[0]["text"]
|
||
# elif model_type == "uni_asr":
|
||
# if i == 0:
|
||
# os.remove(segment_path)
|
||
# continue
|
||
# res = model.generate(
|
||
# input=segment_path
|
||
# )
|
||
# text = res[0]["text"]
|
||
else:
|
||
raise RuntimeError("unknown model type")
|
||
ts2 = time.time()
|
||
generated_text += text
|
||
processing_time += (ts2 - ts1)
|
||
os.remove(segment_path)
|
||
|
||
rtf = processing_time / duration
|
||
print("Text:", generated_text, flush=True)
|
||
print(f"Audio duration:\t{duration:.3f} s", flush=True)
|
||
print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
|
||
print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
|
||
with open(answer_file, 'r', encoding='utf-8') as f:
|
||
groundtruth_text = f.read()
|
||
acc = cal_per_cer(generated_text, groundtruth_text, "zh")
|
||
print(f"1-cer = {acc}", flush=True)
|
||
|
||
return processing_time, acc, generated_text
if __name__ == "__main__":
|
||
test_result = {
|
||
"time_cuda": 0,
|
||
"acc_cuda": 0,
|
||
"text_cuda": "",
|
||
"success": False
|
||
}
|
||
|
||
try:
|
||
model_dict = {
|
||
"sense_voice": "SenseVoiceSmall",
|
||
"paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||
"conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
|
||
"whisper": "Whisper-large-v3",
|
||
"uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline"
|
||
}
|
||
LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true"
|
||
K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true"
|
||
workspace_path = "../" if LOCAL_TEST else "/tmp/workspace"
|
||
model_dir = os.path.join("/model", model_dict["sense_voice"]) if LOCAL_TEST else os.environ["MODEL_DIR"]
|
||
audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"])
|
||
answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"])
|
||
result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"])
|
||
# test_funasr(model_dir, audio_file, answer_file, False)
|
||
processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True)
|
||
test_result["time_cuda"] = processing_time
|
||
test_result["acc_cuda"] = acc
|
||
test_result["text_cuda"] = generated_text
|
||
test_result["success"] = True
|
||
except Exception as e:
|
||
print(f"ASR测试出错: {e}", flush=True)
|
||
with open(result_file, "w", encoding="utf-8") as fp:
|
||
json.dump(test_result, fp, ensure_ascii=False, indent=4)
|
||
# 如果是SUT起来镜像的话,需要加上下面让pod永不停止以迎合k8s deployment, 本地测试以及docker run均不需要
|
||
if K8S_TEST:
|
||
print(f"Start to sleep indefinitely", flush=True)
|
||
time.sleep(100000) |