"""Kokoro-82M (v1.1-zh) text-to-speech HTTP service.

Endpoints:
  * ``POST /tts``             - SSML in, streamed 16-bit PCM audio out.
  * ``GET /health``, ``/ready`` - liveness/readiness probes.
"""
import os
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
from kokoro import KPipeline, KModel
import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List


def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Drop-in replacement for ``torch.nn.ConvTranspose1d.forward`` that runs in fp16 on CUDA.

    Mirrors the upstream PyTorch implementation, but wraps the
    ``F.conv_transpose1d`` call in a CUDA float16 autocast region. It then
    casts the result back to float32 with ``.float()``.

    Raises:
        ValueError: if the layer's ``padding_mode`` is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # One cannot replace List by Tuple or Sequence in "_output_padding" because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
    num_spatial_dims = 1
    output_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        num_spatial_dims, self.dilation)  # type: ignore[arg-type]
    # CUDA autocast only affects CUDA tensors. Enable it only for them, so a
    # CPU-only host does not get PyTorch's "CUDA is not available" warning.
    with torch.amp.autocast('cuda', dtype=torch.float16, enabled=input.is_cuda):
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation).float()


# Process-wide monkey patch: every ConvTranspose1d (including the ones inside
# KModel) uses the fp16 forward above.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward

repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'

# Populated by init() at application startup (see lifespan()).
model = None
en_empty_pipeline = None
en_pipeline = None
zh_pipeline = None
en_voice_pt = None
zh_voice_pt = None

# Voice file names and model directory, overridable via environment variables.
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
model_dir = os.getenv('MODEL_DIR', '/models/hexgrad/Kokoro-82M-v1.1-zh')

# Hand-written pronunciations that take precedence over automatic G2P.
_EN_PHONEME_OVERRIDES = {
    'Kokoro': 'kˈOkəɹO',
    'Sol': 'sˈOl',
}


def en_callable(text):
    """Return the phoneme string for an English word.

    Passed as ``en_callable`` to the Chinese KPipeline in init(). It is
    presumably called for English words embedded in Chinese text (confirm
    against the kokoro docs). Words in ``_EN_PHONEME_OVERRIDES`` use the fixed
    pronunciation. All other words are phonemized by ``en_empty_pipeline``,
    built with ``model=False``; only its ``.phonemes`` output is used.

    Returns '' when the English pipeline yields no chunk for ``text``. A bare
    ``next()`` would raise StopIteration inside kokoro's generator, and Python
    turns that into a RuntimeError that aborts the whole request.
    """
    override = _EN_PHONEME_OVERRIDES.get(text)
    if override is not None:
        return override
    first = next(en_empty_pipeline(text), None)
    return first.phonemes if first is not None else ''


# HACK: Mitigate rushing caused by lack of training data beyond ~100 tokens
# Simple piecewise linear fn that decreases speed as len_ps increases
def speed_callable(len_ps):
    """Map ``len_ps`` (presumably the phoneme-sequence length) to a speed factor.

    * ``len_ps <= 83``        -> 1
    * ``83 < len_ps < 183``   -> linearly from 1 down to 0.8
    * ``len_ps >= 183``       -> 0.8
    """
    if len_ps <= 83:
        return 1
    if len_ps < 183:
        return 1 - (len_ps - 83) / 500
    return 0.8


def audio_postprocess(audio: np.ndarray) -> np.ndarray:
    """Convert float32 samples to int16 PCM; other dtypes pass through unchanged.

    float32 input is clipped to [-1.0, 1.0] before being scaled by 32767.
    Without the clip, out-of-range samples overflow int16 on conversion and
    produce loud wrap-around clicks instead of saturating.
    """
    if audio.dtype == np.float32:
        audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    return audio


def init():
    """Load the Kokoro model and build the English/Chinese pipelines.

    Runs once at startup from lifespan() and fills in the module-level
    globals. It uses CUDA when ``torch.cuda.is_available()``, otherwise CPU.
    Voice paths are resolved under ``<model_dir>/voices``.
    """
    global model, en_empty_pipeline, en_pipeline, zh_pipeline
    global en_voice_pt, zh_voice_pt
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = KModel(repo_id=repo_id,
                   model=os.path.join(model_dir, 'kokoro-v1_1-zh.pth'),
                   config=os.path.join(model_dir, 'config.json')).to(device).eval()
    # model=False: this pipeline is used only for phonemization (see en_callable).
    en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
    en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
    zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
    en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
    zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models before serving requests."""
    init()
    yield


app = FastAPI(lifespan=lifespan)

# ElementTree spelling of the reserved ``xml:`` prefix (used for ``xml:lang``).
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"


# return 24kHz pcm-16
@app.post("/tts")
def generate(ssml: str = Body(...)):
    """Synthesize the text of the first ``<voice>`` element in ``ssml``.

    Language is taken from ``xml:lang`` on ``<voice>``. If that is missing, it
    falls back to ``<speak>``, then to "zh". Tags whose primary subtag is
    ``en`` (e.g. "en", "en-US") use the English pipeline and voice. Everything
    else uses the Chinese pipeline with ``speed_callable``.

    Returns:
        A StreamingResponse of int16 PCM chunks, or a 400 JSONResponse when
        the SSML is not well-formed or has no ``<voice>`` element.
    """
    try:
        # NOTE(review): ssml is client-supplied. ElementTree is not hardened
        # against malicious XML (entity-expansion attacks depend on the expat
        # version); consider defusedxml if this endpoint is exposed publicly.
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    # `{*}` matches <voice> in any namespace or none. This also accepts
    # standard SSML that declares xmlns="http://www.w3.org/2001/10/synthesis".
    voice_element = root.find(".//{*}voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: element not found."})

    # itertext() includes text nested in child elements. It yields nothing,
    # rather than None, for an empty <voice/>.
    text = "".join(voice_element.itertext()).strip()
    lang_attr = f'{xml_namespace}lang'
    language = (voice_element.get(lang_attr) or root.get(lang_attr) or "zh").strip()
    # xml:lang is a BCP 47 tag (e.g. "en-US"); route on its primary subtag.
    is_english = language.replace('_', '-').split('-', 1)[0].lower() == 'en'

    def streaming_generator():
        if is_english:
            generator = en_pipeline(text=text, voice=en_voice_pt)
        else:
            generator = zh_pipeline(text=text, voice=zh_voice_pt, speed=speed_callable)
        for (_, _, audio) in generator:
            if audio is None:  # defensive: a chunk without synthesized audio
                continue
            yield audio_postprocess(audio.cpu().numpy()).tobytes()

    # NOTE(review): the body is raw PCM with no RIFF/WAV header, despite the
    # 'audio/wav' label. The label is kept unchanged for existing clients.
    return StreamingResponse(streaming_generator(), media_type='audio/wav')


@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe; always reports ok once the app is serving."""
    return JSONResponse(status_code=200, content={"status": "ok"})


if __name__ == "__main__":
    # Port defaults to 80; override with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv('PORT', '80')))