# enginex-bi_series-tts/bi_v100-matcha/matcha_server.py
# Matcha-TTS inference server (FastAPI): accepts SSML, streams 16 kHz int16 WAV PCM.
import os

# Model artifacts are mounted into the container; location and checkpoint
# name are overridable via environment variables.
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

import logging

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)

# enable custom patcher if available
# If the mounted model directory ships a custom_patcher.py, copy it into the
# working directory and import it purely for its side effects (presumably
# monkey-patching model code -- contents are model-specific).
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import shutil
    shutil.copyfile(patcher_path, "custom_patcher.py")
    try:
        import custom_patcher
        logger.info("Custom patcher has been applied.")
    except ImportError:
        logger.info("Failed to import custom_patcher. Ensure it is a valid Python module.")
else:
    logger.info("No custom_patcher found.")
import wave
import numpy as np
from scipy.signal import resample
import re
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
import torch

# Cap CPU thread usage and force the plain math SDP attention backend;
# flash / mem-efficient kernels are explicitly disabled here.
# NOTE(review): presumably for numerical or hardware compatibility -- confirm.
torch.set_num_threads(4)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.ConvTranspose1d``.

    Matches the stock PyTorch implementation except that the convolution
    itself runs under fp16 CUDA autocast, with the result cast back to
    float32 before returning.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
    assert isinstance(self.padding, tuple)
    # `_output_padding` takes a List (not Tuple/Sequence) because TorchScript
    # does not support `Sequence[T]` or `Tuple[T, ...]`.
    out_pad = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        1, self.dilation)  # type: ignore[arg-type]  # 1 spatial dim
    # Half precision for the heavy op (no-op when CUDA is unavailable),
    # then hand callers a float32 tensor.
    with torch.amp.autocast('cuda', dtype=torch.float16):
        result = F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            out_pad, self.groups, self.dilation)
    return result.float()

# Globally monkey-patch ConvTranspose1d so the vocoder picks up the
# mixed-precision forward.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
from matcha.cli import load_matcha, load_vocoder, to_waveform, process_text

# Populated by init() during app startup.
model = None
vocoder = None
denoiser = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sample rate (Hz) of the acoustic model / vocoder output.
MODEL_SR = int(os.getenv("MODEL_SR", 22050))
# Passed to model.synthesise() as length_scale.
speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))
# Sample rate (Hz) served to clients after resampling.
TARGET_SR = 16000
# Number of int16 silence samples returned for punctuation-only input.
N_ZEROS = 100
def init():
    """Load the acoustic model and vocoder, then run one warmup synthesis."""
    global model, vocoder, denoiser
    checkpoint = os.path.join(model_dir, model_name)
    model = load_matcha("custom_model", checkpoint, device)
    vocoder, denoiser = load_vocoder(
        "hifigan_T2_v1", os.path.join(model_dir, "generator_v1"), device)
    # Warm up end-to-end (kernels, caches) and discard the audio chunks.
    for _chunk in generate("你好,欢迎使用语音合成服务。"):
        pass
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models before serving; no teardown needed."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
# ElementTree spells namespaced attributes as "{uri}name"; this is the
# standard `xml:` namespace, used below to read `xml:lang`.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"

# Punctuation (ASCII + full-width CJK) that does not count as speakable text.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if *text* contains at least one non-punctuation character."""
    return not all(ch in symbols for ch in text)
def split_text(text, max_chars=135):
    """Split *text* into sentence chunks for incremental synthesis.

    Splits after ASCII sentence punctuation (;:.!?) followed by whitespace,
    and immediately after full-width CJK sentence enders (。!?).
    Whitespace-only pieces are dropped.

    Args:
        text: Input transcription.
        max_chars: Unused; kept for backward compatibility with callers.
            (The byte-budget chunk re-merging it controlled was removed.)

    Returns:
        List of non-empty, stripped sentence strings.
    """
    pieces = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    return [p.strip() for p in pieces if p.strip()]
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* to *target_sr* and convert float audio to int16 PCM.

    Args:
        audio: 1-D waveform; floating-point samples are expected in [-1, 1].
        ori_sr: Sample rate of *audio* in Hz.
        target_sr: Desired output sample rate in Hz.

    Returns:
        Resampled waveform; floating-point data is clipped to [-1, 1] and
        scaled to int16 PCM, other dtypes are returned as-is after resampling.
    """
    if ori_sr != target_sr:
        # scipy FFT-based resampling to the proportional sample count.
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio = resample(audio, number_of_samples)
    # Check the dtype AFTER resampling (resample may change it) and accept
    # any float width, not only float32.
    if np.issubdtype(audio.dtype, np.floating):
        audio = np.clip(audio, -1.0, 1.0)
        audio = (audio * 32767).astype(np.int16)
    return audio
def generate(texts):
    """Yield one chunk of 16 kHz int16 PCM bytes per sentence of *texts*.

    Each sentence is synthesized independently so audio can stream to the
    client before the full text has been processed.  Sentences whose text
    processing fails are logged and skipped.
    """
    for chunk in split_text(texts):
        try:
            text_processed = process_text(0, chunk, device)
        except Exception as e:
            logger.error(f"Error processing text: {e}")
            # BUG FIX: previously fell through and synthesized with a stale
            # (or, on the first chunk, undefined) `text_processed`.
            continue
        with torch.inference_mode():
            output = model.synthesise(
                text_processed["x"],
                text_processed["x_lengths"],
                n_timesteps=10,
                temperature=0.667,
                spks=None,
                length_scale=speaking_rate
            )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, denoiser_strength=0.00025)
        audio = output["waveform"].detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesize speech from an SSML body; stream back 16 kHz int16 PCM.

    Expects a <voice> element whose text is the transcription.  Returns
    HTTP 400 for malformed SSML or a missing <voice> element, and a short
    silence buffer when the transcription is punctuation-only or empty.
    """
    try:
        root = ET.fromstring(ssml)
        voice_element = root.find(".//voice")
        if voice_element is None:
            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
        # BUG FIX: .text is None for an empty <voice/> element; default to ""
        # instead of crashing with AttributeError (HTTP 500).
        transcription = (voice_element.text or "").strip()
        language = voice_element.get(f'{xml_namespace}lang', "zh").strip()  # currently unused
        # voice_name = voice_element.get("name", "zh-f-soft-1").strip()
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
    if not contains_words(transcription):
        # Nothing speakable: reply with a short buffer of silence.
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')
    return StreamingResponse(generate(transcription), media_type='audio/wav')
@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe: always reports success once the app is up."""
    payload = {"message": "success"}
    return JSONResponse(status_code=200, content=payload)
@app.get("/health_check")
async def health_check():
    """Deep health check: verify the GPU can execute a small matmul.

    Returns {"status": "ok"} on success; raises HTTP 503 if the matmul
    fails to run or produces an unexpected result.
    """
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        # 10x10 result where every entry is 20, so the sum must be 2000.
        ok = torch.matmul(a, b).sum() == 10 * 20 * 10
    except Exception:
        # Log with traceback via the module logger instead of a bare print;
        # raising outside the try also avoids catching our own HTTPException.
        logger.exception('health_check failed')
        ok = False
    if ok:
        return {"status": "ok"}
    raise HTTPException(status_code=503)
if __name__ == "__main__":
    # uvicorn is already imported at module top; the duplicate local import
    # was removed.  Bind all interfaces on the standard HTTP port.
    uvicorn.run(app, host="0.0.0.0", port=80)