first revise
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
import torchaudio
|
||||
import torch
|
||||
import torchaudio
|
||||
from funasr import AutoModel
|
||||
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
||||
from utils.calculate import cal_per_cer
|
||||
@@ -34,18 +34,11 @@ def test_funasr(model_dir, audio_file, answer_file, use_gpu):
|
||||
model_name = os.path.basename(model_dir)
|
||||
model_type = determine_model_type(model_name)
|
||||
|
||||
if torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
|
||||
# 天垓100情况下的Whisper需要绕过不支持算子
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
# 不使用VAD, punct,spk模型,就测试原始ASR能力
|
||||
model = AutoModel(
|
||||
model=model_dir,
|
||||
# vad_model="fsmn-vad",
|
||||
# vad_kwargs={"max_single_segment_time": 30000},
|
||||
vad_model=None,
|
||||
device="cuda:0" if use_gpu else "cpu",
|
||||
disable_update=True
|
||||
)
|
||||
@@ -114,14 +107,6 @@ def test_funasr(model_dir, audio_file, answer_file, use_gpu):
|
||||
batch_size_s=300
|
||||
)
|
||||
text = res[0]["text"]
|
||||
# elif model_type == "uni_asr":
|
||||
# if i == 0:
|
||||
# os.remove(segment_path)
|
||||
# continue
|
||||
# res = model.generate(
|
||||
# input=segment_path
|
||||
# )
|
||||
# text = res[0]["text"]
|
||||
else:
|
||||
raise RuntimeError("unknown model type")
|
||||
ts2 = time.time()
|
||||
@@ -142,6 +127,13 @@ def test_funasr(model_dir, audio_file, answer_file, use_gpu):
|
||||
return processing_time, acc, generated_text
|
||||
|
||||
if __name__ == "__main__":
|
||||
if torch.cuda.is_available():
|
||||
cuda_tensor = torch.randn(2, 2, device='cuda:0')
|
||||
print(f"CUDA device index: {cuda_tensor.get_device()}")
|
||||
else:
|
||||
print("CUDA not available")
|
||||
os._exit(1)
|
||||
|
||||
test_result = {
|
||||
"time_cuda": 0,
|
||||
"acc_cuda": 0,
|
||||
@@ -176,5 +168,5 @@ if __name__ == "__main__":
|
||||
json.dump(test_result, fp, ensure_ascii=False, indent=4)
|
||||
# 如果是SUT起来镜像的话,需要加上下面让pod永不停止以迎合k8s deployment, 本地测试以及docker run均不需要
|
||||
if K8S_TEST:
|
||||
print(f"Start to sleep indefinitely", flush=True)
|
||||
time.sleep(100000)
|
||||
print(f"Start to sleep indefinity", flush=True)
|
||||
time.sleep(100000)
|
||||
|
||||
Reference in New Issue
Block a user