Files
enginex-mr_series-asr/fastapi_funasr.py

272 lines
9.5 KiB
Python
Raw Normal View History

2026-02-04 17:34:39 +08:00
import os
import time
import argparse
import torchaudio
import torch
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uuid
import uvicorn
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from funasr.models.fun_asr_nano.model import FunASRNano
# Staging directory for uploaded audio files (see /transduce).
os.makedirs("./input", exist_ok=True)

# Global service state, mutated by the startup handler below.
status = "Running"  # "Running" -> still loading; "Success" -> ready (reported by /health)
model = None        # funasr AutoModel instance, assigned in load_model()
device = ""         # torch device string chosen in load_model()
app = FastAPI()

# Vendor accelerator selection via environment variable; importing the vendor
# package registers its device backend with torch as a side effect.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    # Cambricon MLU backend.
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    # Huawei Ascend NPU backend.
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    # DIPU backend -- NOTE(review): exact meaning of the "pt" prefix is not
    # documented here; confirm against the deployment configuration.
    import torch_dipu
def make_all_dense(module: torch.nn.Module):
    """Replace every sparse parameter and buffer of *module* with a dense copy.

    Some accelerator backends cannot transfer sparse tensors, so each sparse
    parameter is swapped in-place for a dense ``nn.Parameter`` (preserving
    ``requires_grad``) and each non-strided buffer is re-registered densely.
    """

    def _owner_of(qualified_name):
        # Walk a dotted name like "encoder.0.weight" down to the submodule
        # that actually owns the final attribute.
        *path, attr = qualified_name.split(".")
        owner = module
        for part in path:
            owner = getattr(owner, part)
        return owner, attr

    for qname, param in list(module.named_parameters(recurse=True)):
        if param.is_sparse:
            with torch.no_grad():
                dense = param.to_dense().contiguous()
            owner, attr = _owner_of(qname)
            setattr(owner, attr, torch.nn.Parameter(dense, requires_grad=param.requires_grad))

    # Buffers (e.g. running_mean): any non-strided layout is a sparse tensor.
    for qname, buf in list(module.named_buffers(recurse=True)):
        if buf.layout != torch.strided:
            owner, attr = _owner_of(qname)
            owner.register_buffer(attr, buf.to_dense().contiguous(), persistent=True)
def split_audio(waveform, sample_rate, segment_seconds=20):
    """Split a waveform into consecutive fixed-length chunks.

    Args:
        waveform: 2-D tensor shaped (channels, samples).
        sample_rate: sampling rate in Hz.
        segment_seconds: target chunk length in seconds; may be fractional.

    Returns:
        List of (channels, <= segment_samples) tensors covering the waveform
        in order; the last chunk holds the remainder. Empty waveform -> [].

    Raises:
        ValueError: if segment_seconds * sample_rate rounds to < 1 sample.
    """
    # int() allows fractional segment lengths (range() requires an int step).
    segment_samples = int(segment_seconds * sample_rate)
    if segment_samples <= 0:
        raise ValueError("segment_seconds * sample_rate must be at least 1 sample")
    # range() never yields a start index >= the total length, so every slice
    # is non-empty; no filtering needed.
    return [
        waveform[:, start:start + segment_samples]
        for start in range(0, waveform.shape[1], segment_samples)
    ]
# def determine_model_type(model_name):
# if "sensevoice" in model_name.lower():
# return "sensevoice"
# elif "whisper" in model_name.lower():
# return "whisper"
# elif "paraformer" in model_name.lower():
# return "paraformer"
# elif "conformer" in model_name.lower():
# return "conformer"
# elif "uniasr" in model_name.lower():
# return "uni_asr"
# else:
# return "unknown"
@app.on_event("startup")
def load_model():
    """Load the ASR model once at startup, applying per-accelerator workarounds."""
    global status, model, device

    # Config is expected to be injected into app.state before the server starts
    # (not visible in this file) -- TODO confirm against the launcher.
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type = config.get("model_type", "sensevoice")
    warmup = config.get("warmup", False)
    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)

    # Pick the device string based on the vendor prefix in CUSTOM_DEVICE.
    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            # NOTE(review): device index 1 (not 0) is hard-coded here -- confirm intentional.
            device = "npu:1"
        else:
            device = "cuda:0"

    # Accelerator-specific workarounds.
    if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # Whisper on the Iluvatar TianGai-100 must bypass unsupported SDP kernels:
        # disable flash / memory-efficient attention and force the math fallback.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    print(f"device: {device}", flush=True)

    dense_convert = False
    if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
        dense_convert = True
    if device.startswith("npu") and model_type == "whisper":
        # Loading whisper on Ascend NPU hits a device mismatch on its sparse
        # tensors, so densify on CPU first.
        dense_convert = True

    print(f"dense_convert: {dense_convert}", flush=True)
    if dense_convert:
        # Load on CPU, convert all sparse params/buffers to dense, then move
        # the whole model to the accelerator.
        model = AutoModel(
            model=model_dir,
            vad_model=None,
            disable_update=True,
            device="cpu"
        )
        make_all_dense(model.model)
        model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        model.model.to(device)
        model.kwargs["device"] = device
    else:
        # No VAD / punctuation / speaker models: exercise raw ASR capability only.
        model = AutoModel(
            model=model_dir,
            # vad_model="fsmn-vad",
            # vad_kwargs={"max_single_segment_time": 30000},
            vad_model=None,
            device=device,
            disable_update=True
        )

    if device.startswith("npu") or warmup:
        # Ascend NPU scheduling initialization is heavier than on other cards,
        # so always warm up there; other devices warm up only when requested.
        print("Start warmup...", flush=True)
        res = model.generate(input="warmup.wav")
        print("warmup complete.", flush=True)

    status = "Success"
def test_funasr(audio_file, lang):
    """Transcribe *audio_file* with the globally loaded model and return the text.

    Most model types are fed fixed 20-second chunks written to temporary WAV
    files; uni_asr receives the whole file. Prints duration / elapsed / RTF
    stats as a side effect. *lang* is a language hint (e.g. "zh", "en").
    """
    # Decode the audio; duration in seconds is samples / sample_rate.
    waveform, sample_rate = torchaudio.load(audio_file)
    # print(waveform.shape)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)

    generated_text = ""
    processing_time = 0
    model_type = app.state.config.get("model_type", "sensevoice")
    if model_type == "uni_asr":
        # uni_asr is built for long audio and does its own VAD segmentation.
        # Feeding pre-cut 20 s chunks can fail outright: if a chunk is mostly
        # music / non-speech, the VAD may drop all of it and the encoder/decoder
        # receives zero-length input. So pass the whole file instead.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the 20-second chunks one at a time.
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            text = None

            if model_type == "sensevoice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                    # merge_length_s=15,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": lang,
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # paraformer emits output character by character; the extra
                # spaces between Chinese characters hurt the 1-CER metric,
                # so strip them for Chinese.
                if lang == "zh":
                    text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            # elif model_type == "uni_asr":
            #     if i == 0:
            #         os.remove(segment_path)
            #         continue
            #     res = model.generate(
            #         input=segment_path
            #     )
            #     text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            if text is not None:
                # some models output "▁" (9601, Unicode U+2581) as separator between words,
                # replace them with space for better readability
                # NOTE(review): the first replace also swallows plain ASCII "_" -- confirm intended.
                text = text.replace("_", " ")
                text = text.replace(chr(9601), " ")

            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)

    # Real-time factor: processing time relative to audio duration.
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text
@app.get("/health")
def health():
    """Readiness probe: reports model-loading progress set by load_model()."""
    # Startup hook has not finished loading the model yet.
    if status == "Running":
        return {"status": "loading model"}
    # After startup: "Success" means ready; anything else means the load failed.
    return {"status": "ok" if status == "Success" else "failed"}
@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Accept an uploaded audio file, transcribe it, and return the text.

    The upload is staged under ./input with a UUID name. On success the staged
    file is removed by a background task after the response is sent; on failure
    it is removed immediately, because FastAPI only runs background tasks for
    successful responses (without this, every failed request leaks a file).
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        background_tasks.add_task(os.remove, file_path)
        generated_text = test_funasr(file_path, lang)

        return {"generated_text": generated_text}
    except Exception:
        # Clean up the staged file here -- the background task above will not
        # run when the handler raises.
        if os.path.exists(file_path):
            os.remove(file_path)
        # NOTE(review): returning the full traceback to clients exposes
        # internals; acceptable only for an internal benchmark service.
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")
# if __name__ == "__main__":
2026-02-09 13:48:45 +08:00
2026-02-04 17:34:39 +08:00
# uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)