# TTS serving entrypoint — Python, 183 lines, 5.6 KiB
# Last modified: 2025-09-05 11:27:43 +08:00
import logging
import os
import re
import wave
import xml.etree.ElementTree as ET
from contextlib import asynccontextmanager

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from scipy.signal import resample

# Location of the acoustic-model checkpoint; overridable for local runs.
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Use the module name (not __file__) so the %(name)s log field stays a short,
# stable identifier instead of a filesystem path.
logger = logging.getLogger(__name__)

# Model and vocoder run in float32; make that the process-wide default.
torch.set_default_dtype(torch.float32)
# Keep a handle to the stock implementation before monkey-patching.
_original_hann_window = torch.hann_window


def _safe_hann_window(window_length,
                      periodic=True,
                      *,
                      dtype=None,
                      layout=torch.strided,
                      device=None,
                      requires_grad=False,
                      **kwargs):
    """Drop-in replacement for ``torch.hann_window`` that is safe on NPU.

    The NPU backend does not support the int64 / in-place ``cos`` path used
    by the native ``hann_window`` kernel, so the window is always generated
    on the CPU (defaulting to float32) and then moved to the requested
    device. The signature mirrors ``torch.hann_window``.
    """
    if dtype is None:
        dtype = torch.float32
    # Always build on CPU first to bypass the NPU in-place cos implementation.
    win = _original_hann_window(
        window_length,
        periodic=periodic,
        dtype=dtype,
        layout=layout,
        device="cpu",
        requires_grad=requires_grad,
        **kwargs,
    )
    if device is not None:
        win = win.to(device)
    return win


# Patch globally so every downstream STFT/vocoder call gets the safe version.
torch.hann_window = _safe_hann_window
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List
from matcha.cli import load_matcha, load_vocoder, to_waveform, process_text
# Model handles; lazily initialised by init() during app startup.
model = None
vocoder = None
denoiser = None
# Target accelerator for synthesis (Ascend NPU deployment).
device = 'npu'
# Sample rate the acoustic model emits audio at.
MODEL_SR = int(os.getenv("MODEL_SR", 22050))
# Speech-speed multiplier, forwarded to the model as length_scale.
speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))
# All responses are resampled to 16 kHz int16 PCM.
TARGET_SR = 16000
# Number of silent int16 samples returned for punctuation-only input.
N_ZEROS = 100
def init():
    """Load the acoustic model, vocoder and denoiser, then warm them up."""
    global model, vocoder, denoiser
    checkpoint = os.path.join(model_dir, model_name)
    generator = os.path.join(model_dir, "generator_v1")
    model = load_matcha("custom_model", checkpoint, device)
    vocoder, denoiser = load_vocoder("hifigan_T2_v1", generator, device)
    # Warmup: synthesise one sentence end-to-end so the first request is fast.
    for _ in generate("你好,欢迎使用语音合成服务。"):
        pass
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the models once before serving traffic."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
# ElementTree prefixes attributes in the XML namespace with this key.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Punctuation (ASCII + full-width) that carries no speakable content.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if *text* has at least one speakable character.

    A character is speakable if it is neither punctuation (``symbols``) nor
    whitespace. Whitespace previously counted as speakable, which sent
    blank transcriptions to the synthesiser instead of returning silence.
    """
    return any(char not in symbols and not char.isspace() for char in text)
def split_text(text, max_chars=135):
    """Split *text* into sentence-sized chunks for the synthesiser.

    Sentences are cut after ASCII terminators followed by whitespace, or
    after full-width terminators (。!?). Any sentence longer than
    *max_chars* is hard-wrapped so no chunk exceeds the model's input
    budget (previously *max_chars* was accepted but ignored).
    """
    sentences = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    sentences = [s.strip() for s in sentences if s.strip()]
    chunks = []
    for sentence in sentences:
        if len(sentence) <= max_chars:
            chunks.append(sentence)
        else:
            # Hard-wrap overlong unpunctuated runs at the character budget.
            chunks.extend(
                sentence[i:i + max_chars]
                for i in range(0, len(sentence), max_chars)
            )
    return chunks
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from *ori_sr* to *target_sr* and convert to int16 PCM.

    Float input is assumed to be in [-1, 1]; it is clipped and scaled to the
    int16 range. Non-float input is returned (resampled) unchanged.
    """
    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio_resampled = resample(audio, number_of_samples)
    else:
        audio_resampled = audio
    # Accept any float width: the original matched float32 only, so float64
    # buffers were streamed to clients as raw floats by mistake.
    if np.issubdtype(audio.dtype, np.floating):
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    return audio_resampled
def generate(texts):
    """Yield int16 PCM byte chunks synthesised from *texts*, one sentence at
    a time, so the HTTP response can stream while synthesis continues."""
    for chunk in split_text(texts):
        try:
            text_processed = process_text(0, chunk, device)
        except Exception as e:
            # Skip the failing chunk: the original fell through after logging
            # and then hit a NameError on the undefined text_processed,
            # killing the whole stream.
            logger.error(f"Error processing text: {e}")
            continue
        with torch.inference_mode():
            output = model.synthesise(
                text_processed["x"],
                text_processed["x_lengths"],
                n_timesteps=10,
                temperature=0.667,
                spks=None,
                length_scale=speaking_rate,
            )
            output["waveform"] = to_waveform(
                output["mel"], vocoder, denoiser, denoiser_strength=0.00025
            )
        audio = output["waveform"].detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesise speech for the text of the <voice> element in *ssml*.

    Returns 400 for malformed SSML, a short silent buffer for input with no
    speakable characters, otherwise a streaming 16 kHz int16 PCM response.
    """
    try:
        root = ET.fromstring(ssml)
        voice_element = root.find(".//voice")
        if voice_element is not None:
            # .text is None for an empty <voice/> element; the original
            # crashed with AttributeError on .strip() in that case.
            transcription = (voice_element.text or "").strip()
            language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
            # voice_name = voice_element.get("name", "zh-f-soft-1").strip()
        else:
            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
    if not contains_words(transcription):
        # NOTE(review): the payload is headerless PCM, not a WAV container,
        # despite the media type — confirm clients expect raw PCM.
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')
    return StreamingResponse(generate(transcription), media_type='audio/wav')
@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe; reports success once the app is serving."""
    return JSONResponse(status_code=200, content={"message": "success"})
@app.get("/health_check")
async def health_check():
    """Deep health probe: run a tiny matmul on the serving device.

    Raises HTTP 503 when the accelerator cannot execute the computation or
    produces a wrong result.
    """
    try:
        # Use the configured serving device: the original hard-coded 'cuda',
        # which always raised (and returned 503) on the NPU deployment.
        a = torch.ones(10, 20, dtype=torch.float32, device=device)
        b = torch.ones(20, 10, dtype=torch.float32, device=device)
        c = torch.matmul(a, b)
        if c.sum() == 10 * 20 * 10:
            return {"status": "ok"}
    except Exception:
        # logger.exception keeps the traceback; print() discarded it.
        logger.exception("health_check failed")
    raise HTTPException(status_code=503)
if __name__ == "__main__":
    # uvicorn is already imported at module top; the redundant local
    # re-import served no purpose. Bind all interfaces on port 80.
    uvicorn.run(app, host="0.0.0.0", port=80)