# Repository file-viewer header (kept for provenance):
# last modified 2025-08-20 21:28:01 +08:00 | 209 lines | 6.9 KiB | Python

import os
# Location of the model checkpoint and its file name; both can be overridden
# through the environment (MODEL_DIR / MODEL_NAME).
model_dir = os.environ.get("MODEL_DIR", "/mnt/models/")
model_name = os.environ.get("MODEL_NAME", "model.ckpt")
import logging
# Root logging configuration. The level comes from $LOGLEVEL and defaults to INFO.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    # basicConfig only recognises upper-case level names ("info" raises
    # ValueError), so normalise the environment value first.
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
# Module logger named after the module rather than its filesystem path.
logger = logging.getLogger(__name__)
# Apply an optional, deployment-supplied patch module (MODEL_DIR/custom_patcher.py)
# before the model libraries below are imported.
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import importlib.util
    import shutil
    import sys

    # Keep a local copy next to the working directory, as before.
    local_patcher_path = os.path.abspath("custom_patcher.py")
    shutil.copyfile(patcher_path, local_patcher_path)
    try:
        # Load from the explicit file path rather than relying on the current
        # working directory being on sys.path. It is not, for example, when the
        # script is started as `python path/to/file.py` from another directory.
        spec = importlib.util.spec_from_file_location("custom_patcher", local_patcher_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot create an import spec for {local_patcher_path}")
        custom_patcher = importlib.util.module_from_spec(spec)
        # Register before executing so later `import custom_patcher` statements
        # (and self-references inside the module) resolve to this instance.
        sys.modules["custom_patcher"] = custom_patcher
        spec.loader.exec_module(custom_patcher)
        logger.info("Custom patcher has been applied.")
    except ImportError:
        sys.modules.pop("custom_patcher", None)
        # Best-effort: keep serving unpatched, but make the failure visible.
        logger.warning(
            "Failed to import custom_patcher. Ensure it is a valid Python module.",
            exc_info=True,
        )
else:
    logger.info("No custom_patcher found.")
import wave
import numpy as np
from scipy.signal import resample
import re
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
from dataclasses import dataclass
import torch
# Cap PyTorch intra-op CPU threads. The default is 4; override it with
# $TORCH_NUM_THREADS to match the container's CPU allocation.
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
from piper_train.vits.lightning import VitsModel
from piper_phonemize import (
phonemize_espeak,
phoneme_ids_espeak,
)
@dataclass
class LanguageConfig:
    """A loaded voice model together with the espeak voice used to phonemize its input."""

    # Inference-ready VITS model (prepared by init()).
    model: VitsModel
    # espeak voice identifier handed to phonemize_espeak(), e.g. 'cmn'.
    espeak_id: str


# Registry of servable languages, keyed by the language code clients send in the
# SSML xml:lang attribute; populated by init() at application startup.
language_dict: dict[str, LanguageConfig] = dict()
# Inference device: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sample rate of the model's raw output. Environment values are strings, so the
# override is converted to int; arithmetic in audio_postprocess needs a number.
MODEL_SR = int(os.getenv("MODEL_SR", "22050"))
# Sample rate of the audio sent back to clients.
TARGET_SR = 16000
# Number of silent samples returned when the input has nothing to pronounce.
N_ZEROS = 100
# Synthesis scales forwarded to the model as `scales` in generate().
noise_scale, length_scale, noise_w = 0.667, 1.0, 0.8
def _load_vits_model(ckpt_path):
    """Load a VITS checkpoint onto ``device`` and prepare it for inference."""
    logger.info("Loading VITS checkpoint %s on %s", ckpt_path, device)
    model = VitsModel.load_from_checkpoint(ckpt_path, dataset=None).to(device)
    model = model.eval()
    with torch.no_grad():
        # Remove weight normalization from the generator's decoder once, at
        # load time, before any inference call.
        model.model_g.dec.remove_weight_norm()
    return model


def init():
    """Load the serving checkpoint(s) and register them in ``language_dict``.

    The checkpoint at ``os.path.join(model_dir, model_name)`` is registered as
    language ``'zh'`` and phonemized with the espeak voice ``'cmn'``.
    """
    # language_dict is only mutated, never rebound, so no `global` is needed.
    ckpt_path = os.path.join(model_dir, model_name)
    language_dict['zh'] = LanguageConfig(model=_load_vits_model(ckpt_path), espeak_id='cmn')
    # Further voices can be registered the same way, for example:
    #   ar: os.path.join(model_dir, "ar/ar_JO/kareem/medium", 'epoch=5079-step=1682020.ckpt'), espeak_id='ar'
    #   ru: os.path.join(model_dir, "ru/ru_RU/irina/medium", 'epoch=4139-step=929464.ckpt'), espeak_id='ru'
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that loads every voice model before requests are served."""
    # Startup phase: runs once, before the server accepts traffic.
    init()
    # Serving phase; there is nothing to release on shutdown.
    yield


app = FastAPI(lifespan=lifespan)
# Namespace URI of the reserved ``xml:`` prefix in ElementTree's "{uri}" form;
# used to read the ``xml:lang`` attribute.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Punctuation with nothing to pronounce: ASCII forms plus CJK / full-width forms.
symbols = ',.!?;:()[]{}<>，。！？；：、（）【】《》……"“”‘’_—'


def contains_words(text):
    """Return True if ``text`` has at least one character worth synthesizing.

    A character counts when it is neither whitespace nor one of ``symbols``.
    Empty, whitespace-only and punctuation-only strings return False.
    """
    return any(not char.isspace() and char not in symbols for char in text)
# Sentence boundaries are either of:
#   * ASCII terminators followed by whitespace (the whitespace is consumed), or
#   * CJK / full-width terminators, and ASCII "!"/"?", matched zero-width
#     because CJK text has no space after the punctuation.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[;:.!?])\s+|(?<=[。！？!?])")


def _append_separator(sentence):
    """Return ``sentence`` plus a trailing space if its last character is single-byte (ASCII).

    Multi-byte (e.g. CJK) sentences are concatenated without a separator.
    """
    if sentence and len(sentence[-1].encode("utf-8")) == 1:
        return sentence + " "
    return sentence


def split_text(text, max_chars=135):
    """Split ``text`` into sentence-aligned chunks for synthesis.

    Sentences are packed greedily into chunks whose size stays within
    ``max_chars``.

    Args:
        text: Input text, possibly mixing ASCII and CJK punctuation.
        max_chars: Chunk size limit, measured in UTF-8 **bytes**, not
            characters. A single sentence longer than the limit is not cut; it
            becomes its own (oversized) chunk.

    Returns:
        List of stripped, non-empty chunks (``[]`` for empty input).
    """
    chunks = []
    current_chunk = ""
    for raw_sentence in _SENTENCE_SPLIT_RE.split(text):
        sentence = raw_sentence.strip()
        if not sentence:
            continue
        over_limit = len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) > max_chars
        if over_limit and current_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = ""
        current_chunk += _append_separator(sentence)
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``ori_sr`` to ``target_sr`` and convert it to int16 PCM.

    Args:
        audio: Waveform, resampled along axis 0. Floating-point input of any
            precision is treated as normalized to [-1, 1]. Integer input is
            assumed to be int16-scaled already and is not rescaled.
        ori_sr: Sample rate of ``audio``; any value accepted by ``int()``
            (e.g. a string read from the environment).
        target_sr: Desired output sample rate.

    Returns:
        int16 array of ``int(len(audio) * target_sr / ori_sr)`` samples.
    """
    ori_sr, target_sr = int(ori_sr), int(target_sr)
    is_float = np.issubdtype(audio.dtype, np.floating)
    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio = resample(audio, number_of_samples)
    if is_float:
        # Clip every float dtype, not only float32, so out-of-range samples
        # saturate instead of wrapping around in the int16 cast.
        audio = np.clip(audio, -1.0, 1.0) * 32767
    else:
        # Integer PCM keeps its scale; only guard the int16 range, because
        # resampling can overshoot slightly.
        audio = np.clip(audio, -32768, 32767)
    return audio.astype(np.int16)
def _phoneme_ids_for(line, espeak_id):
    """Phonemize one line with espeak and map the phonemes to model input ids."""
    # phonemize_espeak returns nested lists of phonemes; flatten them into one
    # sequence for the whole line.
    phonemes = [
        phoneme
        for sentence_phonemes in phonemize_espeak(line, espeak_id)
        for phoneme in sentence_phonemes
    ]
    return phoneme_ids_espeak(phonemes)


def generate(texts, language):
    """Synthesize ``texts`` chunk by chunk, yielding raw audio bytes.

    Args:
        texts: Text to speak; it is split with split_text() and each non-empty
            chunk is synthesized separately.
        language: Key into ``language_dict`` selecting the model and espeak
            voice. The caller is expected to validate it; an unknown key
            raises KeyError.

    Yields:
        For each chunk, the ``tobytes()`` of audio_postprocess()'s int16 output
        at TARGET_SR. No WAV header is written.
    """
    config = language_dict[language]
    model, espeak_id = config.model, config.espeak_id
    # Inputs that are identical for every chunk: synthesis scales and the
    # single speaker id (0).
    scales = [noise_scale, length_scale, noise_w]
    sid = torch.LongTensor([0]).to(device)

    for chunk in split_text(texts):
        line = chunk.strip()
        if not line:
            continue
        phoneme_ids = _phoneme_ids_for(line, espeak_id)
        text = torch.LongTensor(phoneme_ids).unsqueeze(0).to(device)
        text_lengths = torch.LongTensor([len(phoneme_ids)]).to(device)
        # Inference only: disable autograd so no graph is built for the forward
        # pass. The context is closed before `yield` because grad mode is
        # thread-local and Starlette's StreamingResponse may resume a sync
        # generator on a different threadpool thread.
        with torch.no_grad():
            audio = model(text, text_lengths, scales, sid=sid).detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesize speech for the first ``<voice>`` element of an SSML document.

    The voice's text is synthesized with the model registered in
    ``language_dict`` under the element's ``xml:lang`` attribute.

    Returns:
        A 400 JSONResponse for unparsable SSML, a missing ``<voice>`` element
        or an unsupported language. A short silent clip (N_ZEROS int16
        samples) when contains_words() finds nothing to pronounce. Otherwise a
        StreamingResponse of int16 PCM chunks from generate().

    NOTE(review): both audio responses are labelled ``audio/wav`` but carry no
    WAV header. Clients presumably treat the body as raw PCM at TARGET_SR;
    confirm before changing the media type.
    """
    try:
        # NOTE(review): the body is untrusted XML. The stdlib parser is not
        # hardened against every malicious-XML attack (see the xml package's
        # "XML vulnerabilities" notes); consider defusedxml if that matters here.
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    # "{*}" matches the tag in any namespace or none, so documents declaring the
    # standard SSML namespace (xmlns="http://www.w3.org/2001/10/synthesis") work too.
    voice_element = root.find(".//{*}voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # itertext() also collects text nested inside child elements. It yields ""
    # for an empty <voice/>, whereas `.text` would be None there.
    transcription = "".join(voice_element.itertext()).strip()
    language = voice_element.get(f'{xml_namespace}lang', '').strip()

    if language not in language_dict:
        return JSONResponse(status_code=400, content={"message": f"Language '{language}' is not supported."})
    if not contains_words(transcription):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')
    return StreamingResponse(generate(transcription, language), media_type='audio/wav')
@app.get("/ready")
@app.get("/health")
async def ready():
    """Readiness/health probe: answers 200 unconditionally while the server is running."""
    payload = {"message": "success"}
    return JSONResponse(content=payload, status_code=200)
@app.get("/health_check")
async def health_check():
    """Run a tiny matmul on the inference device to confirm it is usable.

    Returns:
        ``{"status": "ok"}`` when the computation succeeds and is correct.

    Raises:
        HTTPException: status 503 if the computation fails or gives a wrong result.
    """
    try:
        # Use the same device the models were loaded on, so CPU-only
        # deployments are not reported as permanently unhealthy.
        a = torch.ones(10, 20, dtype=torch.float32, device=device)
        b = torch.ones(20, 10, dtype=torch.float32, device=device)
        total = torch.matmul(a, b).sum().item()
    except Exception as e:
        logger.exception("health_check failed")
        raise HTTPException(status_code=503) from e
    # Each of the 10x10 output entries is a dot product of 20 ones.
    if total != 10 * 20 * 10:
        logger.error("health_check failed: unexpected matmul sum %s", total)
        raise HTTPException(status_code=503)
    return {"status": "ok"}
if __name__ == "__main__":
    # Bind address and port default to all interfaces on port 80. They can be
    # overridden with the same variable names the uvicorn CLI reads
    # (UVICORN_HOST / UVICORN_PORT). `uvicorn` is already imported at module level.
    uvicorn.run(
        app,
        host=os.getenv("UVICORN_HOST", "0.0.0.0"),
        port=int(os.getenv("UVICORN_PORT", "80")),
    )