Add fastapi service

fastapi_funasr.py (new file, 265 lines)

@@ -0,0 +1,265 @@
import os
import time
import argparse
import torchaudio
import torch
import traceback

from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uuid
import uvicorn
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from funasr.models.fun_asr_nano.model import FunASRNano

os.makedirs("./input", exist_ok=True)
status = "Running"
model = None
device = ""
app = FastAPI()

# Vendor-specific PyTorch plugins: importing them registers the device backend.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu

def make_all_dense(module: torch.nn.Module):
    for name, param in list(module.named_parameters(recurse=True)):
        if getattr(param, "is_sparse", False) and param.is_sparse:
            with torch.no_grad():
                dense = param.to_dense().contiguous()
            parent = module
            *mods, leaf = name.split(".")
            for m in mods:
                parent = getattr(parent, m)
            setattr(parent, leaf, torch.nn.Parameter(dense, requires_grad=param.requires_grad))

    # Handle buffers as well (e.g. running_mean)
    for name, buf in list(module.named_buffers(recurse=True)):
        # PyTorch sparse tensors do not use the strided layout
        if buf.layout != torch.strided:
            dense = buf.to_dense().contiguous()
            parent = module
            *mods, leaf = name.split(".")
            for m in mods:
                parent = getattr(parent, m)
            parent.register_buffer(leaf, dense, persistent=True)

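# A minimal usage sketch for make_all_dense (illustrative only; the toy Linear
# module below is hypothetical and not part of the service flow):
#
#   lin = torch.nn.Linear(4, 4)
#   lin.weight = torch.nn.Parameter(lin.weight.detach().to_sparse())
#   make_all_dense(lin)
#   assert lin.weight.layout == torch.strided  # densified in place
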
def split_audio(waveform, sample_rate, segment_seconds=20):
    segment_samples = segment_seconds * sample_rate
    segments = []
    for i in range(0, waveform.shape[1], segment_samples):
        segment = waveform[:, i:i + segment_samples]
        if segment.shape[1] > 0:
            segments.append(segment)
    return segments

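# Worked example: a 16 kHz mono waveform of 45 s has shape [1, 720000];
# split_audio returns three segments of shapes [1, 320000], [1, 320000],
# and [1, 80000] (20 s + 20 s + 5 s).
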
# def determine_model_type(model_name):
#     if "sensevoice" in model_name.lower():
#         return "sensevoice"
#     elif "whisper" in model_name.lower():
#         return "whisper"
#     elif "paraformer" in model_name.lower():
#         return "paraformer"
#     elif "conformer" in model_name.lower():
#         return "conformer"
#     elif "uniasr" in model_name.lower():
#         return "uni_asr"
#     else:
#         return "unknown"

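# Note: app.state.config is read in load_model() below but never assigned in
# this file, so it must be injected by whatever launches the app. A hedged
# sketch of a hypothetical launcher (names and values are assumptions):
#
#   from fastapi_funasr import app
#   app.state.config = {"model_dir": "/model", "model_type": "sensevoice",
#                       "use_gpu": True, "warmup": False}
#   uvicorn.run(app, host="0.0.0.0", port=1111)
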
@app.on_event("startup")
def load_model():
    global status, model, device

    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type = config.get("model_type", "sensevoice")
    warmup = config.get("warmup", False)
    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:1"
        else:
            device = "cuda:0"

    # Special handling for specific accelerator cards
    if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # Whisper on the Iluvatar Tiangai-100 has to bypass unsupported SDPA operators
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    print(f"device: {device}", flush=True)

    dense_convert = False
    if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
        dense_convert = True
    if device.startswith("npu") and model_type == "whisper":
        # Loading Whisper on Ascend NPUs hits a device mismatch in its sparse parts
        dense_convert = True

    print(f"dense_convert: {dense_convert}", flush=True)
    if dense_convert:
        model = AutoModel(
            model=model_dir,
            vad_model=None,
            disable_update=True,
            device="cpu"
        )
        make_all_dense(model.model)
        model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        model.model.to(device)
        model.kwargs["device"] = device
    else:
        # No VAD, punctuation, or speaker models: test raw ASR capability only
        model = AutoModel(
            model=model_dir,
            # vad_model="fsmn-vad",
            # vad_kwargs={"max_single_segment_time": 30000},
            vad_model=None,
            device=device,
            disable_update=True
        )

    if device.startswith("npu") or warmup:
        # Ascend NPUs have a different low-level design and more complex
        # device-initialization scheduling than other cards, so warm up first
        print("Start warmup...", flush=True)
        res = model.generate(input="warmup.wav")
        print("warmup complete.", flush=True)

    status = "Success"

def test_funasr(audio_file, lang):
    # Inference
    waveform, sample_rate = torchaudio.load(audio_file)
    # print(waveform.shape)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)

    generated_text = ""
    processing_time = 0
    model_type = app.state.config.get("model_type", "sensevoice")
    if model_type == "uni_asr":
        # uni_asr is a special case: it is designed for long audio and ships its
        # own VAD. With splitting, if the first 20 s is almost all music with no
        # speech it errors out immediately, because the VAD may cut away all the
        # audio and leave the encoder/decoder with zero-length input.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the split segments in order
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            if model_type == "sensevoice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                    # merge_length_s=15,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": lang,
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # paraformer emits text character by character; too many spaces
                # in between would skew the 1-CER result
                text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            # elif model_type == "uni_asr":
            #     if i == 0:
            #         os.remove(segment_path)
            #         continue
            #     res = model.generate(
            #         input=segment_path
            #     )
            #     text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)

    # RTF: real-time factor, processing time divided by audio duration
    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)

    return generated_text

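# Worked RTF example: a 60 s clip transcribed in 6 s gives
# RTF = 6.000/60.000 = 0.100; RTF < 1 means faster than real time.
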
@app.get("/health")
def health():
    if status == "Running":
        return {
            "status": "loading model"
        }
    ret = {
        "status": "ok" if status == "Success" else "failed",
    }
    return ret

@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    try:
        file_path = f"./input/{uuid.uuid4()}.wav"
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        # remove the temp file once the response has been sent
        background_tasks.add_task(os.remove, file_path)
        generated_text = test_funasr(file_path, lang)

        return {"generated_text": generated_text}
    except Exception:
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")

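# Minimal client sketch (port taken from the commented-out uvicorn line below;
# sample.wav is a placeholder file name):
#
#   curl http://localhost:1111/health
#   curl -X POST http://localhost:1111/transduce \
#        -F "audio=@sample.wav" -F "lang=zh"
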
# if __name__ == "__main__":
#     uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)