"""FastAPI text-to-speech service built on Matcha-TTS and a HiFi-GAN vocoder.

The service accepts SSML on ``POST /`` or ``POST /tts``, synthesises the text
of the first ``voice`` element sentence by sentence, and streams the resulting
16-bit samples back to the caller.

NOTE: statement order at module level matters. The optional ``custom_patcher``
module is imported *before* torch / matcha so that it can patch them.
"""
import os

# Location of the model checkpoint (and optional custom patcher / vocoder).
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

import logging

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)

# enable custom patcher if available
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import shutil

    # Copy next to the working directory so it can be imported as a module.
    shutil.copyfile(patcher_path, "custom_patcher.py")
    try:
        import custom_patcher

        logger.info("Custom patcher has been applied.")
    except ImportError:
        logger.info("Failed to import custom_patcher. Ensure it is a valid Python module.")
else:
    logger.info("No custom_patcher found.")

import wave
import numpy as np
from scipy.signal import resample
import re
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
import torch

torch.set_num_threads(4)
# torch.backends.cuda.enable_flash_sdp(False)
# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_math_sdp(True)

from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List

# Disabled experiment: run ConvTranspose1d under fp16 autocast.
# def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
#     if self.padding_mode != 'zeros':
#         raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
#     assert isinstance(self.padding, tuple)
#     # One cannot replace List by Tuple or Sequence in "_output_padding" because
#     # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
#     num_spatial_dims = 1
#     output_padding = self._output_padding(
#         input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
#         num_spatial_dims, self.dilation)  # type: ignore[arg-type]
#     with torch.amp.autocast('cuda', dtype=torch.float16):
#         return F.conv_transpose1d(
#             input, self.weight, self.bias, self.stride, self.padding,
#             output_padding, self.groups, self.dilation).float()
# torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward

from matcha.cli import load_matcha, load_vocoder, to_waveform, process_text

# Populated by init() at application startup.
model = None
vocoder = None
denoiser = None

device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_SR = int(os.getenv("MODEL_SR", 22050))  # sample rate passed as the source rate
speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))  # passed as length_scale
TARGET_SR = 16000  # sample rate of the audio returned to clients
N_ZEROS = 100  # number of zero samples returned for word-less input


def init():
    """Load the acoustic model and vocoder into the module globals and warm up.

    Reads the checkpoint ``model_name`` and the vocoder ``generator_v1`` from
    ``model_dir``, then runs one synthesis and discards its output.
    """
    global model, vocoder, denoiser
    ckpt_path = os.path.join(model_dir, model_name)
    vocoder_path = os.path.join(model_dir, "generator_v1")
    model = load_matcha("custom_model", ckpt_path, device)
    vocoder, denoiser = load_vocoder("hifigan_T2_v1", vocoder_path, device)
    # warmup:
    for _ in generate("你好,欢迎使用语音合成服务。"):
        pass


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the models before serving requests."""
    init()
    yield
    pass


app = FastAPI(lifespan=lifespan)

xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Characters that do not count as speakable content.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if *text* has at least one character worth synthesising.

    A character counts when it is neither listed in ``symbols`` nor
    whitespace; whitespace is ignored so that input such as ``". ."`` is
    treated as punctuation-only.
    """
    return any(char not in symbols and not char.isspace() for char in text)


def split_text(text, max_chars=135):
    """Split *text* into sentences on terminal punctuation.

    The split happens at whitespace following one of ``; : . ! ?`` and
    directly after one of ``。 ! ?``. Empty pieces are dropped and each piece
    is stripped.

    Args:
        text: The text to split.
        max_chars: Currently unused; kept for interface compatibility with
            the disabled chunk-merging logic below.

    Returns:
        A list of non-empty sentence strings.
    """
    sentences = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    sentences = [s.strip() for s in sentences if s.strip()]
    return sentences
    # Disabled: merge sentences into chunks of at most ``max_chars`` bytes.
    # chunks = []
    # current_chunk = ""
    # for sentence in sentences:
    #     if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
    #         current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
    #     else:
    #         if current_chunk:
    #             chunks.append(current_chunk.strip())
    #         current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
    # if current_chunk:
    #     chunks.append(current_chunk.strip())
    # return chunks


def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* to *target_sr* and convert float32 input to int16.

    Args:
        audio: One-dimensional sample array.
        ori_sr: Sample rate of *audio*.
        target_sr: Desired output sample rate.

    Returns:
        The resampled array. When the *input* dtype is float32 the samples
        are clipped to [-1.0, 1.0], scaled by 32767 and cast to int16;
        otherwise the resampled array is returned without conversion.
    """
    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio_resampled = resample(audio, number_of_samples)
    else:
        audio_resampled = audio

    # The check is deliberately on the input dtype, not the resampled one.
    if audio.dtype == np.float32:
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)

    return audio_resampled


def generate(texts):
    """Synthesise *texts* sentence by sentence, yielding audio bytes.

    Args:
        texts: Text to synthesise; it is split with ``split_text``.

    Yields:
        For each sentence, the bytes of the post-processed sample array
        (resampled from ``MODEL_SR`` to ``TARGET_SR``). Sentences whose text
        processing fails are logged and skipped.
    """
    for chunk in split_text(texts):
        try:
            text_processed = process_text(0, chunk, device)
        except Exception as e:
            logger.error(f"Error processing text: {e}")
            # Skip this sentence: without this, the code below would hit an
            # unbound (or stale, previous-sentence) ``text_processed``.
            continue

        with torch.inference_mode():
            output = model.synthesise(
                text_processed["x"],
                text_processed["x_lengths"],
                n_timesteps=10,
                temperature=0.667,
                spks=None,
                length_scale=speaking_rate,
            )
            output["waveform"] = to_waveform(
                output["mel"], vocoder, denoiser, denoiser_strength=0.00025
            )
        audio = output["waveform"].detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()


@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesise the text of the first ``voice`` element found in *ssml*.

    Args:
        ssml: SSML document sent as the request body.

    Returns:
        * 400 JSON response when the SSML cannot be parsed or contains no
          ``voice`` element.
        * A short buffer of ``N_ZEROS`` zero samples when the text has no
          speakable characters (including an empty ``voice`` element).
        * Otherwise a streaming response fed by ``generate``.

    NOTE(review): the body is the raw output of ``ndarray.tobytes()``; no WAV
    header is written here even though the media type is ``audio/wav`` --
    confirm that clients expect header-less samples.
    """
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid SSML format", "Exception": str(e)},
        )

    voice_element = root.find(".//voice")
    if voice_element is None:
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid SSML format: element not found."},
        )

    # ``Element.text`` is None for an empty element (or one that starts with
    # a child element); treat that as empty text instead of crashing.
    transcription = (voice_element.text or "").strip()
    # voice_name = voice_element.get("name", "zh-f-soft-1").strip()

    if not contains_words(transcription):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')

    return StreamingResponse(generate(transcription), media_type='audio/wav')


@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe: always answers 200 once the app is up."""
    return JSONResponse(status_code=200, content={"message": "success"})


@app.get("/health_check")
async def health_check():
    """Run a tiny matrix multiplication on the CUDA device as a sanity check.

    Returns:
        ``{"status": "ok"}`` when the product sums to the expected value.

    Raises:
        HTTPException: 503 when the computation fails or gives a wrong result.
    """
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        c = torch.matmul(a, b)
        healthy = bool(c.sum() == 10 * 20 * 10)
    except Exception as e:
        # Keep the traceback in the service log instead of a bare print.
        logger.exception("health_check failed")
        raise HTTPException(status_code=503) from e

    if not healthy:
        logger.error("health_check failed")
        raise HTTPException(status_code=503)
    return {"status": "ok"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=80)