Files
enginex-ascend-910-asr/test_funasr.py
2025-09-04 11:19:41 +08:00

255 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import torchaudio
import torch
import uuid
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from utils.calculate import cal_per_cer
import json
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "").lower()
if CUSTOM_DEVICE.startswith("mlu"):
import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
import torch_dipu
def make_all_dense(module: torch.nn.Module):
    """Replace every sparse parameter and buffer of *module* with a dense copy.

    Some backends (e.g. Ascend NPU) cannot move sparse tensors between
    devices, so each sparse ``Parameter``/buffer is swapped in place for a
    dense, contiguous tensor.  ``requires_grad`` and buffer persistence are
    preserved.

    Parameters
    ----------
    module : torch.nn.Module
        The module tree to densify (modified in place).
    """
    for name, param in list(module.named_parameters(recurse=True)):
        if getattr(param, "is_sparse", False) and param.is_sparse:
            with torch.no_grad():
                dense = param.to_dense().contiguous()
            # Walk down to the sub-module that actually owns the leaf attribute.
            parent = module
            *mods, leaf = name.split(".")
            for m in mods:
                parent = getattr(parent, m)
            setattr(parent, leaf, torch.nn.Parameter(dense, requires_grad=param.requires_grad))
    # Handle buffers as well (e.g. running_mean); sparse tensors are the
    # ones whose layout is not strided.
    for name, buf in list(module.named_buffers(recurse=True)):
        if buf.layout != torch.strided:
            dense = buf.to_dense().contiguous()
            parent = module
            *mods, leaf = name.split(".")
            for m in mods:
                parent = getattr(parent, m)
            # Preserve the buffer's persistence: only buffers already exposed
            # in the owner's state_dict are persistent.  (Fixes the original,
            # which forced persistent=True on every replaced buffer and thus
            # changed state_dict contents for non-persistent buffers.)
            persistent = leaf in parent.state_dict()
            parent.register_buffer(leaf, dense, persistent=persistent)
def split_audio(waveform, sample_rate, segment_seconds=20):
    """Chop a (channels, samples) waveform into fixed-length chunks.

    Each chunk holds ``segment_seconds * sample_rate`` samples along dim 1;
    the final chunk keeps whatever remainder is left and may be shorter.
    Returns the list of chunk tensors (views into *waveform*).
    """
    step = segment_seconds * sample_rate
    total = waveform.shape[1]
    chunks = []
    start = 0
    while start < total:
        chunks.append(waveform[:, start:start + step])
        start += step
    return chunks
def determine_model_type(model_name):
    """Map a model directory basename to a known ASR model family.

    Matching is case-insensitive substring search; the first marker that
    appears in *model_name* wins.  Returns ``"unknown"`` when no marker
    matches.
    """
    lowered = model_name.lower()
    # Ordered: earlier markers take precedence over later ones.
    markers = (
        ("sensevoice", "sense_voice"),
        ("whisper", "whisper"),
        ("paraformer", "paraformer"),
        ("conformer", "conformer"),
        ("uniasr", "uni_asr"),
    )
    for marker, family in markers:
        if marker in lowered:
            return family
    return "unknown"
def test_funasr(model_dir, audio_file, answer_file, use_gpu):
    """Benchmark one FunASR model: transcribe *audio_file* and score it.

    Parameters
    ----------
    model_dir : str
        Model directory; its basename decides the model family
        (see ``determine_model_type``).
    audio_file : str
        Path of the audio file to transcribe.
    answer_file : str
        Path of the UTF-8 ground-truth transcript.
    use_gpu : bool
        When True, choose an accelerator from CUSTOM_DEVICE
        (mlu:0 / npu:1 / cuda:0); otherwise run on CPU.

    Returns
    -------
    tuple
        ``(processing_time_seconds, accuracy_1_minus_cer, generated_text)``.

    Raises
    ------
    RuntimeError
        If the model family cannot be determined.
    """

    def warmup(segment, sample_rate, model):
        # One throw-away inference so device initialization cost is not
        # counted in the timed run.
        print("Start warmup...", flush=True)
        temp_file = f"{str(uuid.uuid4())}.wav"
        torchaudio.save(temp_file, segment, sample_rate)
        model.generate(input=temp_file)
        os.remove(temp_file)
        print("warmup complete.", flush=True)

    model_name = os.path.basename(model_dir)
    model_type = determine_model_type(model_name)
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:1"
        else:
            device = "cuda:0"
    # Accelerator-specific workarounds.
    if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # Whisper on Iluvatar BI-V100 must bypass the unsupported SDP
        # kernels and fall back to the math implementation.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    print(f"device: {device}", flush=True)
    dense_convert = False
    if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
        dense_convert = True
    if device.startswith("npu") and model_type == "whisper":
        # On Ascend NPU the sparse parts of Whisper hit a device mismatch
        # when loading, so densify on CPU first.
        dense_convert = True
    print(f"dense_convert: {dense_convert}", flush=True)
    if dense_convert:
        # Load on CPU, densify all sparse tensors, then move to the device.
        model = AutoModel(
            model=model_dir,
            vad_model=None,
            disable_update=True,
            device="cpu"
        )
        make_all_dense(model.model)
        model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        model.model.to(device)
        model.kwargs["device"] = device
    else:
        # No VAD / punctuation / speaker sub-models: benchmark the raw ASR
        # capability only.
        model = AutoModel(
            model=model_dir,
            vad_model=None,
            device=device,
            disable_update=True
        )
    # Inference.
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)
    if device.startswith("npu"):
        # Ascend NPU device initialization/scheduling is heavier than on
        # other cards, so warm up before the timed loop.  Prefer the second
        # segment but fall back to the first for audio shorter than 20 s
        # (fixes an IndexError when only one segment exists).
        warmup(segments[1] if len(segments) > 1 else segments[0], sample_rate, model)
    generated_text = ""
    processing_time = 0
    if model_type == "uni_asr":
        # UniASR is designed for long audio and runs its own VAD split; a
        # 20 s chunk that is almost all music may leave no speech after VAD
        # (zero-length encoder input) and crash, so feed the whole file.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the 20 s segments one by one.
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            if model_type == "sense_voice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # "zn", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": "zh",
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # Paraformer emits char-by-char with many interleaved
                # spaces; strip them so they do not hurt the 1-cer score.
                text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    with open(answer_file, 'r', encoding='utf-8') as f:
        groundtruth_text = f.read()
    acc = cal_per_cer(generated_text, groundtruth_text, "zh")
    print(f"1-cer = {acc}", flush=True)
    return processing_time, acc, generated_text
if __name__ == "__main__":
    test_result = {
        "time_cuda": 0,
        "acc_cuda": 0,
        "text_cuda": "",
        "success": False
    }
    # Read the switches and a result-file fallback BEFORE the try-block:
    # os.getenv cannot raise, and K8S_TEST / result_file must be defined
    # even when the try body fails early (e.g. a missing MODEL_DIR /
    # RESULT_FILE env var) — otherwise the code after `except` would die
    # with a NameError and result.json would never be written.
    LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true"
    K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true"
    result_file = "result.json"
    try:
        model_dict = {
            "sense_voice": "SenseVoiceSmall",
            "paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
            "whisper": "Whisper-large-v3",
            "uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline"
        }
        workspace_path = "../" if LOCAL_TEST else "/tmp/workspace"
        model_dir = os.path.join("/model", model_dict["uni_asr"]) if LOCAL_TEST else os.environ["MODEL_DIR"]
        audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"])
        answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"])
        result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"])
        processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True)
        test_result["time_cuda"] = processing_time
        test_result["acc_cuda"] = acc
        test_result["text_cuda"] = generated_text
        test_result["success"] = True
    except Exception as e:
        # Top-level boundary: record the failure, still emit the result JSON.
        print(f"ASR测试出错: {e}", flush=True)
    with open(result_file, "w", encoding="utf-8") as fp:
        json.dump(test_result, fp, ensure_ascii=False, indent=4)
    # When the SUT runs as an image under a k8s Deployment the pod must
    # never exit; local tests and plain `docker run` do not need this.
    if K8S_TEST:
        print("Start to sleep indefinitely", flush=True)
        time.sleep(100000)