import os
import time
import argparse
import torchaudio
import torch
import traceback
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
import uuid
import uvicorn
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from funasr.models.fun_asr_nano.model import FunASRNano

# Working directory for uploaded audio files.
os.makedirs("./input", exist_ok=True)

# Global service state shared across request handlers.
status = "Running"  # "Running" -> model still loading, "Success" -> ready
model = None        # AutoModel instance, assigned in load_model()
device = ""         # resolved inference device string, e.g. "cuda:0"

app = FastAPI()

# Optional accelerator backend selector, e.g. "mlu" (Cambricon), "ascend"
# (Huawei NPU) or "pt" (DIPU). Importing the vendor package registers the
# corresponding torch device backend as a side effect.
CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "")
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu
elif CUSTOM_DEVICE.startswith("ascend"):
    import torch_npu
elif CUSTOM_DEVICE.startswith("pt"):
    import torch_dipu


def make_all_dense(module: torch.nn.Module) -> None:
    """Convert every sparse parameter and buffer of *module* to dense, in place.

    Some accelerator backends cannot move sparse tensors between devices, so
    checkpoints containing sparse weights must be densified before calling
    ``module.to(device)``.

    Args:
        module: the model whose parameters/buffers are rewritten in place.
    """
    for name, param in list(module.named_parameters(recurse=True)):
        if getattr(param, "is_sparse", False) and param.is_sparse:
            with torch.no_grad():
                dense = param.to_dense().contiguous()
            # Walk down to the owning submodule so the parameter can be
            # replaced where it is registered.
            parent = module
            *mods, leaf = name.split(".")
            for m in mods:
                parent = getattr(parent, m)
            setattr(parent, leaf, torch.nn.Parameter(dense, requires_grad=param.requires_grad))
    # Handle buffers (e.g. running_mean) as well.
    for name, buf in list(module.named_buffers(recurse=True)):
        # Sparse tensors in PyTorch do not use the strided layout.
        if buf.layout != torch.strided:
            dense = buf.to_dense().contiguous()
            parent = module
            *mods, leaf = name.split(".")
            for m in mods:
                parent = getattr(parent, m)
            parent.register_buffer(leaf, dense, persistent=True)


def split_audio(waveform, sample_rate, segment_seconds=20):
    """Split a (channels, samples) waveform into consecutive fixed-length chunks.

    Args:
        waveform: 2-D tensor shaped (channels, samples) as returned by
            ``torchaudio.load`` — TODO confirm callers never pass 1-D audio.
        sample_rate: samples per second of *waveform*.
        segment_seconds: target chunk length in seconds (the final chunk may
            be shorter).

    Returns:
        List of tensors, each at most ``segment_seconds * sample_rate`` wide.
    """
    segment_samples = segment_seconds * sample_rate
    segments = []
    for start in range(0, waveform.shape[1], segment_samples):
        chunk = waveform[:, start:start + segment_samples]
        if chunk.shape[1] > 0:
            segments.append(chunk)
    return segments
@app.on_event("startup")
def load_model():
    """Load the ASR model according to ``app.state.config`` and mark the service ready.

    Reads ``use_gpu``, ``model_dir``, ``model_type`` and ``warmup`` from the
    config, resolves the inference device (including vendor accelerators),
    applies device-specific workarounds, then sets the globals ``model``,
    ``device`` and ``status``.
    """
    global status, model, device
    config = app.state.config
    use_gpu = config.get("use_gpu", True)
    model_dir = config.get("model_dir", "/model")
    model_type = config.get("model_type", "sensevoice")
    warmup = config.get("warmup", False)
    print(">> Startup config:")
    print(" model_dir =", model_dir, flush=True)
    print(" model_type =", model_type, flush=True)
    print(" use_gpu =", use_gpu, flush=True)
    print(" warmup =", warmup, flush=True)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        elif CUSTOM_DEVICE.startswith("ascend"):
            device = "npu:1"
        else:
            device = "cuda:0"

    # Accelerator-specific workarounds.
    if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # Whisper on Iluvatar BI-V100 must bypass the unsupported fused SDP
        # kernels and fall back to the math implementation.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    print(f"device: {device}", flush=True)

    dense_convert = False
    if device == "cuda:0" and CUSTOM_DEVICE.startswith("pt") and model_type == "whisper":
        dense_convert = True
    if device.startswith("npu") and model_type == "whisper":
        # Loading whisper on Ascend NPU hits a device mismatch in the sparse
        # parts of the checkpoint, so densify on CPU first.
        dense_convert = True
    print(f"dense_convert: {dense_convert}", flush=True)

    if dense_convert:
        # Load on CPU, densify sparse weights, then move to the target device.
        model = AutoModel(
            model=model_dir,
            vad_model=None,
            disable_update=True,
            device="cpu"
        )
        make_all_dense(model.model)
        model.model.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        model.model.to(device)
        model.kwargs["device"] = device
    else:
        # No VAD/punct/speaker models: benchmark the raw ASR capability only.
        model = AutoModel(
            model=model_dir,
            # vad_model="fsmn-vad",
            # vad_kwargs={"max_single_segment_time": 30000},
            vad_model=None,
            device=device,
            disable_update=True
        )

    if device.startswith("npu") or warmup:
        # Ascend NPU initialization/scheduling is heavier than on other
        # devices, so run a warmup inference before serving traffic.
        print("Start warmup...", flush=True)
        res = model.generate(input="warmup.wav")
        print("warmup complete.", flush=True)

    status = "Success"


def test_funasr(audio_file, lang):
    """Transcribe *audio_file* with the loaded model and print timing stats.

    Long audio is split into 20-second segments (except for ``uni_asr``,
    which handles long audio natively) and transcribed segment by segment.

    Args:
        audio_file: path to a WAV file readable by torchaudio.
        lang: language hint, used only by the whisper decoding options.

    Returns:
        The concatenated transcription text.

    Raises:
        RuntimeError: if the configured model type is unknown.
    """
    waveform, sample_rate = torchaudio.load(audio_file)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)
    generated_text = ""
    processing_time = 0
    model_type = app.state.config.get("model_type", "sensevoice")

    if model_type == "uni_asr":
        # uni_asr is special: it is designed for long audio (built-in VAD).
        # Feeding it 20 s slices can fail outright — a slice that is almost
        # all music/silence may be trimmed to zero-length encoder input.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the fixed-length segments one by one. The temp filename embeds
        # a per-request UUID so concurrent requests cannot clobber each
        # other's segment files (the old fixed name "temp_seg_{i}.wav" did).
        run_id = uuid.uuid4().hex
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{run_id}_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            try:
                ts1 = time.time()
                if model_type == "sensevoice":
                    res = model.generate(
                        input=segment_path,
                        cache={},
                        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                        use_itn=True,
                        batch_size_s=60,
                        merge_vad=False,
                        # merge_length_s=15,
                    )
                    text = rich_transcription_postprocess(res[0]["text"])
                elif model_type == "whisper":
                    DecodingOptions = {
                        "task": "transcribe",
                        "language": lang,
                        "beam_size": None,
                        "fp16": False,
                        "without_timestamps": False,
                        "prompt": None,
                    }
                    res = model.generate(
                        DecodingOptions=DecodingOptions,
                        input=segment_path,
                        batch_size_s=0,
                    )
                    text = res[0]["text"]
                elif model_type == "paraformer":
                    res = model.generate(
                        input=segment_path,
                        batch_size_s=300
                    )
                    text = res[0]["text"]
                    # paraformer emits character by character; the extra
                    # spaces would hurt the 1-CER score, so strip them.
                    text = text.replace(" ", "")
                elif model_type == "conformer":
                    res = model.generate(
                        input=segment_path,
                        batch_size_s=300
                    )
                    text = res[0]["text"]
                else:
                    raise RuntimeError("unknown model type")
                ts2 = time.time()
                generated_text += text
                processing_time += (ts2 - ts1)
            finally:
                # Always clean up the temp segment, even when generate()
                # raises (the original leaked the file on error).
                os.remove(segment_path)

    # Guard against zero-length audio: avoid ZeroDivisionError in the RTF.
    rtf = processing_time / duration if duration > 0 else 0.0
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    return generated_text


@app.get("/health")
def health():
    """Report service readiness: 'loading model' during startup, then ok/failed."""
    if status == "Running":
        return {
            "status": "loading model"
        }
    # NOTE(review): the visible code only ever sets "Running"/"Success", so
    # "failed" is reachable only if status is set elsewhere — confirm.
    ret = {
        "status": "ok" if status == "Success" else "failed",
    }
    return ret


@app.post("/transduce")
def transduce(
    audio: UploadFile = File(...),
    lang: str = Form("zh"),
    background_tasks: BackgroundTasks = None
):
    """Accept an uploaded audio file, transcribe it, and return the text.

    The upload is persisted under ./input and removed after the response is
    sent (success path) or immediately on failure — FastAPI background tasks
    only run on successful responses, so the error path must clean up itself.
    """
    file_path = f"./input/{uuid.uuid4()}.wav"
    try:
        with open(file_path, "wb") as f:
            f.write(audio.file.read())
        generated_text = test_funasr(file_path, lang)
        # Defer deletion until after the response has been sent.
        background_tasks.add_task(os.remove, file_path)
        return {"generated_text": generated_text}
    except Exception:
        # Background tasks never run for failed responses: remove the
        # uploaded file here or it would leak on every processing error.
        if os.path.exists(file_path):
            os.remove(file_path)
        # NOTE(review): the traceback is exposed to the client — acceptable
        # for an internal benchmark service, not for public deployment.
        raise HTTPException(status_code=500, detail=f"Processing failed: \n{traceback.format_exc()}")


# if __name__ == "__main__":
#     uvicorn.run("fastapi_funasr:app", host="0.0.0.0", port=1111, workers=1)