Files
enginex-bi_series-tts/bi_v100-kokoro/kokoro_server.py
2025-08-14 10:02:15 +08:00

133 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
from kokoro import KPipeline, KModel
import numpy as np
# from scipy.signal import resample
import torch
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement forward for ``torch.nn.ConvTranspose1d``.

    Runs the transposed convolution under fp16 CUDA autocast and hands
    callers an fp32 tensor again. Interface mirrors the stock
    ``ConvTranspose1d.forward``; only the ``'zeros'`` padding mode is
    supported, exactly as upstream.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
    assert isinstance(self.padding, tuple)
    # `_output_padding` wants a List (not Tuple/Sequence) because TorchScript
    # cannot express `Sequence[T]` / `Tuple[T, ...]`; the literal 1 is the
    # number of spatial dimensions of a 1-d convolution.
    out_pad = self._output_padding(
        input, output_size, self.stride, self.padding,  # type: ignore[arg-type]
        self.kernel_size, 1, self.dilation)  # type: ignore[arg-type]
    # Heavy op in half precision, result cast back to full precision.
    with torch.amp.autocast('cuda', dtype=torch.float16):
        half_out = F.conv_transpose1d(
            input, self.weight, self.bias, self.stride,
            self.padding, out_pad, self.groups, self.dilation)
        return half_out.float()


# Monkey-patch: every ConvTranspose1d in the process now takes the fp16 path.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
# HuggingFace repo id of the bilingual (Chinese/English) Kokoro TTS model.
repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'
# MODEL_SR = 24000

# Populated by init(), which runs from the FastAPI lifespan hook at startup.
model = None              # KModel instance shared by both synthesis pipelines
en_empty_pipeline = None  # English pipeline built with model=False; used only to phonemize (see en_callable)
en_pipeline = None        # English synthesis pipeline
zh_pipeline = None        # Chinese synthesis pipeline
en_voice_pt = None        # resolved path to the English voice tensor (.pt)
zh_voice_pt = None        # resolved path to the Chinese voice tensor (.pt)

# Voice files and model directory are overridable via environment variables.
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
model_dir = os.getenv('MODEL_DIR', '/models/hexgrad/Kokoro-82M-v1.1-zh')
def en_callable(text):
    """Return the phoneme string for an English fragment embedded in Chinese text.

    A few proper nouns carry hand-tuned pronunciations; anything else is
    phonemized through the model-less English pipeline.
    """
    overrides = {
        'Kokoro': 'kˈOkəɹO',
        'Sol': 'sˈOl',
    }
    if text in overrides:
        return overrides[text]
    return next(en_empty_pipeline(text)).phonemes
def speed_callable(len_ps):
    """HACK: mitigate rushing caused by lack of training data beyond ~100 tokens.

    Simple piecewise-linear schedule over the phoneme count: full speed up
    to 83 phonemes, then a linear ramp down, flooring at 0.8 from 183 on
    (the ramp meets the floor exactly at len_ps == 183).
    """
    if len_ps <= 83:
        return 1
    if len_ps < 183:
        return 1 - (len_ps - 83) / 500
    return 0.8
# def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
# ori_dtype = data.dtype
# # data = normalize_audio(data)
# number_of_samples = int(len(data) * float(target_rate) / original_rate)
# resampled_data = resample(data, number_of_samples)
# # resampled_data = normalize_audio(resampled_data)
# return resampled_data.astype(ori_dtype)
def audio_postprocess(audio: np.ndarray) -> np.ndarray:
    """Convert floating-point audio in [-1.0, 1.0] to 16-bit PCM samples.

    Samples are clipped to [-1.0, 1.0] before scaling: without the clip,
    out-of-range floats overflow int16 and wrap around, producing loud
    clicks. Accepts any floating dtype (float32 from the model, but also
    float64); already-integer audio is returned unchanged.
    """
    if np.issubdtype(audio.dtype, np.floating):
        audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    return audio
def init():
    """Load the Kokoro model, pipelines and voice paths into module globals.

    Called once at startup from the FastAPI lifespan hook. Uses CUDA when
    available, otherwise falls back to CPU.
    """
    global model, en_empty_pipeline, en_pipeline, zh_pipeline
    global en_voice_pt, zh_voice_pt
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = KModel(repo_id=repo_id, model=os.path.join(model_dir, 'kokoro-v1_1-zh.pth'), config=os.path.join(model_dir, 'config.json')).to(device).eval()
    # model=False: this pipeline only phonemizes English fragments for
    # en_callable; it never synthesises audio itself.
    # NOTE(review): lang_code 'a' appears to select English and 'z' Chinese
    # per kokoro's convention - confirm against the kokoro docs.
    en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
    en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
    zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
    # Resolved paths of the voice embedding tensors handed to the pipelines.
    en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
    zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models before serving; nothing to tear down."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
# ElementTree qualifies attributes in the XML namespace with this prefix,
# so the SSML attribute xml:lang must be looked up as '{...namespace...}lang'.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Returns raw 24 kHz 16-bit mono PCM. NOTE(review): the response is labelled
# 'audio/wav' but no RIFF header is written - clients must treat the body as
# raw PCM; left unchanged since existing clients may rely on this media type.
@app.post("/tts")
def generate(ssml: str = Body(...)):
    """Synthesize speech from an SSML document.

    The text of the first <voice> element is spoken; its xml:lang attribute
    selects the pipeline ('en' -> English, anything else -> Chinese, default
    'zh'). Returns 400 on unparsable SSML, a missing <voice> element, or an
    empty <voice> body (previously an empty body crashed with a 500);
    otherwise streams PCM chunks as each segment is generated.
    """
    try:
        root = ET.fromstring(ssml)
        voice_element = root.find(".//voice")
        if voice_element is None:
            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
        # voice_element.text is None for <voice/> - guard before .strip().
        text = (voice_element.text or "").strip()
        if not text:
            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element has no text."})
        language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    def streaming_generator():
        # Chinese synthesis applies the speed schedule (see speed_callable);
        # the English pipeline runs at its default speed.
        if language == 'en':
            generator = en_pipeline(text=text, voice=en_voice_pt)
        else:
            generator = zh_pipeline(text=text, voice=zh_voice_pt, speed=speed_callable)
        for (_, _, audio) in generator:
            yield audio_postprocess(audio.numpy()).tobytes()

    return StreamingResponse(streaming_generator(), media_type='audio/wav')
@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe; always 200 once the app is serving requests."""
    return JSONResponse(status_code=200, content={"status": "ok"})
if __name__ == "__main__":
    # Bind on all interfaces, port 80 - presumably container-internal; verify
    # exposure before running outside a container.
    uvicorn.run(app, host="0.0.0.0", port=80)