# Piper-based multilingual TTS inference server (FastAPI + VITS).
import os
import logging

# Model artifact locations; both can be overridden via environment variables.
model_dir = os.getenv("MODEL_DIR", "/mnt/models/")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Use the module name rather than __file__ so log records carry a stable,
# path-independent logger name.
logger = logging.getLogger(__name__)
# Enable a user-supplied patcher module if one ships alongside the model.
# NOTE(review): this imports arbitrary code from the model directory —
# acceptable only if the model volume is trusted.
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import shutil

    # Copy next to this file so a plain `import` can resolve it via sys.path.
    shutil.copyfile(patcher_path, "custom_patcher.py")
    try:
        import custom_patcher  # noqa: F401  (imported for its side effects)

        logger.info("Custom patcher has been applied.")
    except ImportError:
        # A broken patcher deserves more visibility than plain info.
        logger.warning("Failed to import custom_patcher. Ensure it is a valid Python module.")
else:
    logger.info("No custom_patcher found.")
# Standard library
import re
import wave
import xml.etree.ElementTree as ET
from contextlib import asynccontextmanager
from dataclasses import dataclass

# Third-party
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from scipy.signal import resample

# Keep CPU usage bounded when running alongside other workloads.
torch.set_num_threads(4)

# Model / phonemizer (project-local)
from piper_train.vits.lightning import VitsModel
from piper_phonemize import (
    phonemize_espeak,
    phoneme_ids_espeak,
)
@dataclass
class LanguageConfig:
    """Per-language inference bundle: a loaded VITS model plus the espeak
    voice identifier used to phonemize text for that language."""

    # Loaded, eval-mode VITS checkpoint used for synthesis.
    model: VitsModel
    # espeak-ng voice id passed to phonemize_espeak (e.g. 'cmn' for Mandarin).
    espeak_id: str
# Registry of loaded models keyed by language code ('zh', 'ar', ...);
# populated by init() at application startup. Annotation is a string so it
# is not evaluated before LanguageConfig/VitsModel are importable.
language_dict: "dict[str, LanguageConfig]" = {}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# os.getenv returns a *string* when the variable is set in the environment,
# so convert explicitly — the resampling math needs an integer sample rate.
MODEL_SR = int(os.getenv("MODEL_SR", 22050))
TARGET_SR = 16000
# Number of silent int16 samples returned for punctuation-only requests.
N_ZEROS = 100
# VITS inference scales: sampling temperature, speaking rate, duration noise.
noise_scale, length_scale, noise_w = 0.667, 1.0, 0.8
def init():
    """Load the VITS checkpoint and register it in language_dict.

    Called once at application startup (via the FastAPI lifespan hook).
    Additional languages can be registered the same way: load their
    checkpoint and add a LanguageConfig with the matching espeak voice id
    (e.g. 'ar' for Arabic, 'ru' for Russian).
    """
    global language_dict

    ckpt_path = os.path.join(model_dir, model_name)
    model = VitsModel.load_from_checkpoint(ckpt_path, dataset=None).to(device)
    model = model.eval()
    with torch.no_grad():
        # Weight norm is a training-time reparameterization; stripping it
        # speeds up inference without changing the outputs.
        model.model_g.dec.remove_weight_norm()
    language_dict['zh'] = LanguageConfig(model=model, espeak_id='cmn')
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the models before serving requests."""
    init()
    yield
    # No shutdown cleanup required.


app = FastAPI(lifespan=lifespan)
# Namespace prefix ElementTree uses for xml:* attributes (e.g. xml:lang).
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Punctuation (ASCII + CJK) that carries no pronounceable content.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True when *text* has at least one non-punctuation character,
    i.e. something worth synthesizing."""
    for char in text:
        if char not in symbols:
            return True
    return False
def split_text(text, max_chars=135):
    """Split *text* into sentence-aligned chunks of at most *max_chars*
    UTF-8 bytes each (a single oversized sentence still becomes its own
    chunk). When sentences are joined, a space is re-inserted only after
    sentences ending in a single-byte (ASCII) character; CJK sentence
    enders are concatenated directly.
    """

    def byte_len(s):
        # The budget is measured in encoded bytes, not characters.
        return len(s.encode("utf-8"))

    def pad(sentence):
        # Re-insert a separating space only after ASCII final characters.
        if sentence and byte_len(sentence[-1]) == 1:
            return sentence + " "
        return sentence

    pieces = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    sentences = [piece.strip() for piece in pieces if piece.strip()]

    chunks = []
    pending = ""
    for sentence in sentences:
        if byte_len(pending) + byte_len(sentence) <= max_chars:
            pending += pad(sentence)
        else:
            if pending:
                chunks.append(pending.strip())
            pending = pad(sentence)

    if pending:
        chunks.append(pending.strip())

    return chunks
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from ori_sr to target_sr and return int16 PCM.

    Accepts floating-point audio in [-1, 1] (the model's native output) or
    already-scaled integer PCM; the result is always int16 so the streamed
    bytes have a consistent sample format.

    Fixes two dtype bugs in the original: (1) scipy's resample always
    returns a float array, so resampled *integer* input was previously
    returned as float64 instead of int16; (2) only float32 input was
    converted, so float64 audio leaked through unscaled.
    """
    # Remember the input convention before resampling changes the dtype:
    # floats are normalized to [-1, 1], integers are already PCM-scaled.
    input_is_float = np.issubdtype(audio.dtype, np.floating)

    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio_resampled = resample(audio, number_of_samples)
    else:
        audio_resampled = audio

    if input_is_float:
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    elif audio_resampled.dtype != np.int16:
        # Integer input that went through resample(): round back to int16.
        audio_resampled = np.clip(np.round(audio_resampled), -32768, 32767).astype(np.int16)
    return audio_resampled
def generate(texts, language):
    """Yield TARGET_SR int16 PCM byte chunks synthesized from *texts*.

    The text is split into sentence-sized chunks, each phonemized with
    espeak and run through the language's VITS model; audio is resampled
    from MODEL_SR to TARGET_SR before being yielded, so the response can
    be streamed chunk by chunk.
    """
    config = language_dict[language]
    model = config.model
    espeak_id = config.espeak_id

    # Inference parameters are identical for every chunk; build them once
    # instead of re-allocating tensors inside the loop.
    scales = [noise_scale, length_scale, noise_w]
    sid = torch.LongTensor([0]).to(device)  # single-speaker models use id 0

    for chunk in split_text(texts):
        line = chunk.strip()
        if not line:
            continue
        # phonemize_espeak returns one phoneme list per sentence; flatten
        # them into a single sequence before id conversion.
        all_phonemes = phonemize_espeak(line, espeak_id)
        phonemes = [
            phoneme
            for sentence_phonemes in all_phonemes
            for phoneme in sentence_phonemes
        ]
        phoneme_ids = phoneme_ids_espeak(phonemes)

        text = torch.LongTensor(phoneme_ids).unsqueeze(0).to(device)
        text_lengths = torch.LongTensor([len(phoneme_ids)]).to(device)
        audio = model(text, text_lengths, scales, sid=sid).detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesize speech from an SSML request body.

    Expects a <voice xml:lang="..."> element and streams 16 kHz int16 PCM.
    Returns 400 for malformed SSML, a missing <voice> element, or an
    unsupported language; punctuation-only input yields a short buffer of
    silence without invoking the model.
    """
    try:
        root = ET.fromstring(ssml)
        voice_element = root.find(".//voice")
        if voice_element is not None:
            # An empty <voice/> element has text=None; treat it as "" so we
            # return silence instead of crashing with AttributeError.
            transcription = (voice_element.text or "").strip()
            language = voice_element.get(f'{xml_namespace}lang', '').strip()
        else:
            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    if language not in language_dict:
        return JSONResponse(status_code=400, content={"message": f"Language '{language}' is not supported."})

    if not contains_words(transcription):
        # Nothing pronounceable: answer with N_ZEROS samples of silence.
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')

    return StreamingResponse(generate(transcription, language), media_type='audio/wav')
@app.get("/ready")
@app.get("/health")
async def ready():
    """Liveness/readiness probe: reports success once the app is serving."""
    payload = {"message": "success"}
    return JSONResponse(status_code=200, content=payload)
@app.get("/health_check")
async def health_check():
    """Deep health probe: run a small matmul on the GPU and verify the sum.

    NOTE(review): 'cuda' is hard-coded, so this endpoint returns 503 on
    CPU-only hosts even though serving works — presumably intentional as a
    GPU-specific probe; confirm before switching it to `device`.

    Raises HTTPException(503) when the computation fails or is wrong.
    """
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        c = torch.matmul(a, b)
        # Each of the 10x10 outputs is a sum of 20 ones.
        ok = c.sum() == 10 * 20 * 10
    except Exception:
        # Log the full traceback (the original printed a bare message and
        # discarded the exception), then report unhealthy.
        logger.exception("health_check failed")
        raise HTTPException(status_code=503)
    if ok:
        return {"status": "ok"}
    # Wrong result: report unhealthy without routing through the except
    # branch (the original raise inside `try` was caught and re-raised,
    # spuriously logging a failure message).
    raise HTTPException(status_code=503)
if __name__ == "__main__":
    # uvicorn is already imported at module level; no need to re-import it.
    uvicorn.run(app, host="0.0.0.0", port=80)