import os model_dir = os.getenv("MODEL_DIR", "/mnt/models/") model_name = os.getenv("MODEL_NAME", "model.ckpt") import logging logging.basicConfig( format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO"), ) logger = logging.getLogger(__file__) # enable custom patcher if available patcher_path = os.path.join(model_dir, "custom_patcher.py") if os.path.exists(patcher_path): import shutil shutil.copyfile(patcher_path, "custom_patcher.py") try: import custom_patcher logger.info("Custom patcher has been applied.") except ImportError: logger.info("Failed to import custom_patcher. Ensure it is a valid Python module.") else: logger.info("No custom_patcher found.") import wave import numpy as np from scipy.signal import resample import re from fastapi import FastAPI, Response, Body, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from contextlib import asynccontextmanager import uvicorn import xml.etree.ElementTree as ET from dataclasses import dataclass import torch torch.set_num_threads(4) from piper_train.vits.lightning import VitsModel from piper_phonemize import ( phonemize_espeak, phoneme_ids_espeak, ) @dataclass class LanguageConfig: model: VitsModel espeak_id: str language_dict: dict[str, LanguageConfig] = {} # model = None device = 'cuda' if torch.cuda.is_available() else 'cpu' MODEL_SR = os.getenv("MODEL_SR", 22050) TARGET_SR = 16000 N_ZEROS = 100 noise_scale, length_scale, noise_w = 0.667, 1.0, 0.8 def init(): # global model global language_dict ckpt_path = os.path.join(model_dir, model_name) # zh: # ckpt_path = os.path.join(model_dir, "zh/zh_CN/huayan/medium", 'epoch=3269-step=2460540.ckpt') model = VitsModel.load_from_checkpoint(ckpt_path, dataset=None).to(device) model = model.eval() with torch.no_grad(): model.model_g.dec.remove_weight_norm() language_dict['zh'] = LanguageConfig(model=model, espeak_id='cmn') # # ar: # ckpt_path = os.path.join(model_dir, "ar/ar_JO/kareem/medium", 'epoch=5079-step=1682020.ckpt') # model = VitsModel.load_from_checkpoint(ckpt_path, dataset=None).to(device) # model = model.eval() # with torch.no_grad(): # model.model_g.dec.remove_weight_norm() # language_dict['ar'] = LanguageConfig(model=model, espeak_id='ar') # # ru: # ckpt_path = os.path.join(model_dir, "ru/ru_RU/irina/medium", 'epoch=4139-step=929464.ckpt') # model = VitsModel.load_from_checkpoint(ckpt_path, dataset=None).to(device) # model = model.eval() # with torch.no_grad(): # model.model_g.dec.remove_weight_norm() # language_dict['ru'] = LanguageConfig(model=model, espeak_id='ru') @asynccontextmanager async def lifespan(app: FastAPI): init() yield pass app = FastAPI(lifespan=lifespan) xml_namespace = "{http://www.w3.org/XML/1998/namespace}" symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—' def contains_words(text): return any(char not in symbols for char in text) def split_text(text, max_chars=135): sentences = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text) sentences = [s.strip() for s in sentences if s.strip()] chunks = [] current_chunk = "" for sentence in sentences: if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars: current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence else: if current_chunk: chunks.append(current_chunk.strip()) current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence if current_chunk: chunks.append(current_chunk.strip()) return chunks def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray: if ori_sr != target_sr: number_of_samples = int(len(audio) * float(target_sr) / ori_sr) audio_resampled = resample(audio, number_of_samples) else: audio_resampled = audio if audio.dtype == np.float32: audio_resampled = np.clip(audio_resampled, -1.0, 1.0) audio_resampled = (audio_resampled * 32767).astype(np.int16) return audio_resampled def generate(texts, language): chunks = split_text(texts) model = language_dict[language].model espeak_id = language_dict[language].espeak_id for i, chunk in enumerate(chunks): line = chunk.strip() if not line: continue all_phonemes = phonemize_espeak(line, espeak_id) phonemes = [ phoneme for sentence_phonemes in all_phonemes for phoneme in sentence_phonemes ] phoneme_ids = phoneme_ids_espeak(phonemes) text = torch.LongTensor(phoneme_ids).unsqueeze(0).to(device) text_lengths = torch.LongTensor([len(phoneme_ids)]).to(device) scales = [noise_scale, length_scale, noise_w] speaker_id = 0 sid = torch.LongTensor([speaker_id]).to(device) audio = model(text, text_lengths, scales, sid=sid).detach().cpu().squeeze().numpy() yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes() @app.post("/") @app.post("/tts") def predict(ssml: str = Body(...)): try: root = ET.fromstring(ssml) voice_element = root.find(".//voice") if voice_element is not None: transcription = voice_element.text.strip() language = voice_element.get(f'{xml_namespace}lang', '').strip() # voice_name = voice_element.get("name", "zh-f-soft-1").strip() else: return JSONResponse(status_code=400, content={"message": "Invalid SSML format: element not found."}) except ET.ParseError as e: return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)}) if language not in language_dict: return JSONResponse(status_code=400, content={"message": f"Language '{language}' is not supported."}) if not contains_words(transcription): audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes() return Response(audio, media_type='audio/wav') return StreamingResponse(generate(transcription, language), media_type='audio/wav') @app.get("/ready") @app.get("/health") async def ready(): return JSONResponse(status_code=200, content={"message": "success"}) @app.get("/health_check") async def health_check(): try: a = torch.ones(10, 20, dtype=torch.float32, device='cuda') b = torch.ones(20, 10, dtype=torch.float32, device='cuda') c = torch.matmul(a, b) if c.sum() == 10 * 20 * 10: return {"status": "ok"} else: raise HTTPException(status_code=503) except Exception as e: print(f'health_check failed') raise HTTPException(status_code=503) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=80)