"""FastAPI text-to-speech service.

Matcha-TTS acoustic model + HiFi-GAN vocoder running on an Ascend NPU.
Configuration comes from environment variables: MODEL_DIR, MODEL_NAME,
MODEL_SR, SPEAKING_RATE, LOGLEVEL.
"""
import logging
import os
import re
import wave
import xml.etree.ElementTree as ET
from contextlib import asynccontextmanager
from typing import List, Optional

import numpy as np
import torch
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Response
from fastapi.responses import JSONResponse, StreamingResponse
from scipy.signal import resample
from torch import Tensor
from torch.nn import functional as F

from matcha.cli import load_matcha, load_vocoder, process_text, to_waveform

# --- configuration --------------------------------------------------------
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# getLogger(__name__), not __file__: __file__ would embed the absolute
# install path into every log record and vary per deployment.
logger = logging.getLogger(__name__)

torch.set_default_dtype(torch.float32)

_original_hann_window = torch.hann_window


def _safe_hann_window(window_length, periodic=True, *, dtype=None,
                      layout=torch.strided, device=None,
                      requires_grad=False, **kwargs):
    """Drop-in replacement for ``torch.hann_window``.

    The NPU backend does not support the int64 hann_window path (its
    in-place ``cos`` implementation), so the window is always generated
    on the CPU in float32 and only then moved to the requested device.
    """
    if dtype is None:
        dtype = torch.float32
    win = _original_hann_window(
        window_length,
        periodic=periodic,
        dtype=dtype,
        layout=layout,
        device="cpu",  # build on CPU to bypass the NPU in-place cos op
        requires_grad=requires_grad,
        **kwargs,
    )
    return win if device is None else win.to(device)


torch.hann_window = _safe_hann_window

# --- global model state (populated once by init()) ------------------------
model = None
vocoder = None
denoiser = None
device = "npu"
MODEL_SR = int(os.getenv("MODEL_SR", 22050))            # vocoder output rate
speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))  # length_scale for synthesis
TARGET_SR = 16000  # sample rate delivered to clients
N_ZEROS = 100      # silence samples returned for punctuation-only input


def init():
    """Load the acoustic model and vocoder, then run one warm-up synthesis.

    Populates the module-level ``model``, ``vocoder`` and ``denoiser``
    globals. The warm-up pays one-off compile/graph costs before the
    first real request arrives.
    """
    global model, vocoder, denoiser
    ckpt_path = os.path.join(model_dir, model_name)
    vocoder_path = os.path.join(model_dir, "generator_v1")
    model = load_matcha("custom_model", ckpt_path, device)
    vocoder, denoiser = load_vocoder("hifigan_T2_v1", vocoder_path, device)
    # warmup: drain the generator so the full pipeline actually runs
    for _ in generate("你好,欢迎使用语音合成服务。"):
        pass


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models before serving; no teardown."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
# XML namespace prefix used for the SSML xml:lang attribute lookup.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Punctuation considered "non-word" for the silence shortcut below.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if *text* has at least one non-punctuation character."""
    return any(char not in symbols for char in text)


def split_text(text, max_chars=135):
    """Split *text* into synthesis chunks of at most *max_chars* characters.

    Text is first split at ASCII/CJK sentence-ending punctuation; any
    single sentence longer than max_chars is then hard-split. Output is
    identical to a pure sentence split for all sentences <= max_chars.
    """
    sentences = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    sentences = [s.strip() for s in sentences if s.strip()]
    chunks = []
    for sentence in sentences:
        # max_chars was previously accepted but ignored; enforce it so a
        # single run-on sentence cannot exceed the model's input budget.
        chunks.extend(
            sentence[i:i + max_chars]
            for i in range(0, len(sentence), max_chars)
        )
    return chunks


def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from *ori_sr* to *target_sr* and convert to int16 PCM.

    float32 input is assumed to be in [-1, 1] (clipped, then scaled to
    int16); any other dtype is returned resampled but otherwise untouched.
    """
    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio_resampled = resample(audio, number_of_samples)
    else:
        audio_resampled = audio
    if audio.dtype == np.float32:
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    return audio_resampled


def generate(texts):
    """Yield 16 kHz int16 PCM byte chunks, one per chunk of *texts*."""
    for chunk in split_text(texts):
        try:
            text_processed = process_text(0, chunk, device)
        except Exception as e:
            logger.error("Error processing text: %s", e)
            # BUGFIX: the original fell through here and then hit a
            # NameError on the undefined text_processed; skip the chunk.
            continue
        with torch.inference_mode():
            output = model.synthesise(
                text_processed["x"],
                text_processed["x_lengths"],
                n_timesteps=10,
                temperature=0.667,
                spks=None,
                length_scale=speaking_rate,
            )
            output["waveform"] = to_waveform(
                output["mel"], vocoder, denoiser, denoiser_strength=0.00025
            )
        audio = output["waveform"].detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()


@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesise speech for the text inside the SSML <voice> element.

    Returns 400 for unparsable SSML or a missing/empty <voice> element,
    a short silence for punctuation-only input, and otherwise streams
    raw 16 kHz int16 PCM.
    """
    # NOTE(security): xml.etree is not hardened against malicious XML
    # (e.g. entity expansion); consider defusedxml if clients are untrusted.
    try:
        root = ET.fromstring(ssml)
        voice_element = root.find(".//voice")
        # Also guard against an empty <voice/> whose .text is None, which
        # previously raised AttributeError and surfaced as a 500.
        if voice_element is None or voice_element.text is None:
            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: element not found."})
        transcription = voice_element.text.strip()
        language = voice_element.get(f'{xml_namespace}lang', "zh").strip()  # parsed, not yet used
        # voice_name = voice_element.get("name", "zh-f-soft-1").strip()
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
    if not contains_words(transcription):
        # Punctuation-only input: return a short burst of silence.
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')
    # NOTE(review): body is headerless PCM despite the audio/wav media
    # type — presumably clients expect raw samples; confirm.
    return StreamingResponse(generate(transcription), media_type='audio/wav')


@app.get("/health")
@app.get("/ready")
async def ready():
    """Shallow liveness/readiness probe: process is up and serving."""
    return JSONResponse(status_code=200, content={"message": "success"})


@app.get("/health_check")
async def health_check():
    """Deep health probe: run a tiny matmul on the inference device.

    Raises HTTP 503 when the device computation fails or is wrong.
    """
    try:
        # BUGFIX: use the configured inference device; the original
        # hard-coded 'cuda', which always fails on an NPU-only host and
        # made this endpoint permanently return 503.
        a = torch.ones(10, 20, dtype=torch.float32, device=device)
        b = torch.ones(20, 10, dtype=torch.float32, device=device)
        c = torch.matmul(a, b)
        if c.sum() == 10 * 20 * 10:
            return {"status": "ok"}
        raise HTTPException(status_code=503)
    except HTTPException:
        raise
    except Exception:
        logger.exception("health_check failed")
        raise HTTPException(status_code=503)


if __name__ == "__main__":
    # uvicorn is already imported at the top of the file.
    uvicorn.run(app, host="0.0.0.0", port=80)