"""FastAPI text-to-speech service built on Kokoro.

Accepts SSML via POST, synthesizes speech with a Kokoro pipeline selected by
the ``xml:lang`` attribute of the ``<voice>`` element, and streams 16-bit PCM
audio (optionally preceded by a WAV header) resampled to ``TARGET_SR``.
"""
import io
import os
import re
import wave
import xml.etree.ElementTree as ET
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Optional, List

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from kokoro import KPipeline, KModel
# import soundfile as sf
from scipy.signal import resample
from torch import Tensor
from torch.nn import functional as F

import kokoro.istftnet as _ist


def _inverse_no_complex(self, magnitude, phase):
    """Inverse STFT that keeps complex arithmetic off the accelerator.

    Works around MLU devices not supporting complex computation: the real and
    imaginary parts are computed on the input device, moved to the CPU,
    combined into a complex tensor there, and passed to ``torch.istft``.

    Args:
        magnitude: Magnitude spectrogram tensor.
        phase: Phase tensor with the same shape as ``magnitude``.

    Returns:
        The waveform, moved back to the device of ``magnitude``, with an extra
        dimension inserted at position -2.
    """
    device = magnitude.device
    dtype = magnitude.dtype

    real = magnitude * torch.cos(phase)
    imag = magnitude * torch.sin(phase)

    # The complex tensor only ever exists on the CPU.
    spec_complex_cpu = torch.complex(real.to("cpu"), imag.to("cpu"))
    win_cpu = torch.hann_window(self.win_length, device="cpu", dtype=dtype)

    wav_cpu = torch.istft(
        spec_complex_cpu,
        n_fft=self.filter_length,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=win_cpu,
        center=True,
        normalized=False,
        onesided=True,
    )
    return wav_cpu.to(device).unsqueeze(-2)


def _transform_no_complex(self, input_data):
    """Real-valued STFT returning ``(magnitude, phase)``.

    Calls ``torch.stft`` with ``return_complex=False`` and derives magnitude
    and phase manually from the real/imaginary components.
    """
    z = torch.stft(
        input_data,
        n_fft=self.filter_length,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.window.to(input_data.device, dtype=input_data.dtype),
        return_complex=False,
        center=True,
        normalized=False,
    )
    real = z[..., 0]
    imag = z[..., 1]
    magnitude = torch.sqrt(real * real + imag * imag)
    phase = torch.atan2(imag, real)
    return magnitude, phase


# Replace Kokoro's TorchSTFT.inverse / TorchSTFT.transform implementations.
_ist.TorchSTFT.inverse = _inverse_no_complex
_ist.TorchSTFT.transform = _transform_no_complex

repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'
MODEL_SR = 24000
TARGET_SR = 16000
# Number of zero samples appended after every synthesized chunk (silence
# between chunks). For scale: 5000 samples is about 0.2 seconds at MODEL_SR;
# the current value of 20 is roughly 1 ms at TARGET_SR.
N_ZEROS = 20
# Torch device used for the model and for the /health_check probe.
DEVICE = os.getenv('TTS_DEVICE', 'mlu')

model = None
en_empty_pipeline = None
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
model_dir = os.getenv('MODEL_DIR', '/model/hexgrad')
model_name = os.getenv('MODEL_NAME', 'kokoro-v1_1-zh.pth')
# model_1_1_dir = os.path.join(model_dir, 'Kokoro-82M-v1.1-zh')
# model_1_0_dir = os.path.join(model_dir, 'Kokoro-82M')
# repo_id_1_0 = 'hexgrad/Kokoro-82M'


@dataclass
class LanguagePipeline:
    """A Kokoro pipeline paired with the voice file it should use."""

    pipeline: KPipeline
    voice_pt: str


pipeline_dict: dict[str, LanguagePipeline] = {}


def en_callable(text):
    """Return phonemes for an English fragment embedded in Chinese text.

    Two words have hard-coded pronunciations; everything else is phonemized
    by the model-less English pipeline created in ``init``.
    """
    if text == 'Kokoro':
        return 'kˈOkəɹO'
    if text == 'Sol':
        return 'sˈOl'
    return next(en_empty_pipeline(text)).phonemes


# HACK: Mitigate rushing caused by lack of training data beyond ~100 tokens
# Simple piecewise linear fn that decreases speed as len_ps increases
def speed_callable(len_ps):
    """Map a phoneme count to a speed factor (1 up to 83, then down to 0.8)."""
    if len_ps <= 83:
        return 1
    if len_ps < 183:
        return 1 - (len_ps - 83) / 500
    return 0.8


# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    """Build a WAV header followed by ``frame_input``.

    This should be the first chunk of a streamed WAV file; later chunks should
    not carry a header (otherwise artifacts are audible at each chunk start).
    """
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()


def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
    """Resample ``data`` from ``original_rate`` to ``target_rate``.

    The result is cast back to the dtype of the input.
    """
    ori_dtype = data.dtype
    # data = normalize_audio(data)
    number_of_samples = int(len(data) * float(target_rate) / original_rate)
    resampled_data = resample(data, number_of_samples)
    # resampled_data = normalize_audio(resampled_data)
    return resampled_data.astype(ori_dtype)


def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
    """Resample, convert to 16-bit PCM, and append ``N_ZEROS`` of silence.

    Floating-point audio is clipped to [-1, 1] before scaling so that
    resampling overshoot cannot wrap around when cast to int16.
    """
    audio = resample_audio(data, original_rate, target_rate)
    if np.issubdtype(audio.dtype, np.floating):
        audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    audio = np.concatenate([audio, np.zeros(N_ZEROS, dtype=np.int16)])
    return audio


def init():
    """Load the model, build the language pipelines, and run a warm-up."""
    global model, en_empty_pipeline
    # global model_1_0

    model = KModel(
        repo_id=repo_id,
        model=os.path.join(model_dir, model_name),
        config=os.path.join(model_dir, 'config.json'),
    ).to(DEVICE).eval()
    # Model-less pipeline: used only for phonemizing English (see en_callable).
    en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
    en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
    zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
    en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
    zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)
    pipeline_dict['zh'] = LanguagePipeline(pipeline=zh_pipeline, voice_pt=zh_voice_pt)
    pipeline_dict['en'] = LanguagePipeline(pipeline=en_pipeline, voice_pt=en_voice_pt)

    # v1.0 model for other languages
    # model_1_0 = KModel(repo_id=repo_id_1_0, model=os.path.join(model_1_0_dir, 'kokoro-v1_0.pth'), config=os.path.join(model_1_0_dir, 'config.json')).to(device).eval()
    # # es
    # es_pipeline = KPipeline(lang_code='e', repo_id=repo_id_1_0, model=model_1_0)
    # es_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ef_dora.pt')
    # pipeline_dict['es'] = LanguagePipeline(pipeline=es_pipeline, voice_pt=es_voice_pt)
    # # fr
    # fr_pipeline = KPipeline(lang_code='f', repo_id=repo_id_1_0, model=model_1_0)
    # fr_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ff_siwis.pt')
    # pipeline_dict['fr'] = LanguagePipeline(pipeline=fr_pipeline, voice_pt=fr_voice_pt)
    # # hi
    # hi_pipeline = KPipeline(lang_code='h', repo_id=repo_id_1_0, model=model_1_0)
    # hi_voice_pt = os.path.join(model_1_0_dir, 'voices', 'hf_alpha.pt')
    # pipeline_dict['hi'] = LanguagePipeline(pipeline=hi_pipeline, voice_pt=hi_voice_pt)
    # # it
    # it_pipeline = KPipeline(lang_code='i', repo_id=repo_id_1_0, model=model_1_0)
    # it_voice_pt = os.path.join(model_1_0_dir, 'voices', 'if_sara.pt')
    # pipeline_dict['it'] = LanguagePipeline(pipeline=it_pipeline, voice_pt=it_voice_pt)
    # # ja
    # ja_pipeline = KPipeline(lang_code='j', repo_id=repo_id_1_0, model=model_1_0)
    # ja_voice_pt = os.path.join(model_1_0_dir, 'voices', 'jf_alpha.pt')
    # pipeline_dict['ja'] = LanguagePipeline(pipeline=ja_pipeline, voice_pt=ja_voice_pt)
    # # pt
    # pt_pipeline = KPipeline(lang_code='p', repo_id=repo_id_1_0, model=model_1_0)
    # pt_voice_pt = os.path.join(model_1_0_dir, 'voices', 'pf_dora.pt')
    # pipeline_dict['pt'] = LanguagePipeline(pipeline=pt_pipeline, voice_pt=pt_voice_pt)

    warmup()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the model and pipelines before the app starts serving."""
    init()
    yield


app = FastAPI(lifespan=lifespan)


def warmup():
    """Run one short Chinese synthesis and discard the output."""
    zh_pipeline = pipeline_dict['zh'].pipeline
    voice = pipeline_dict['zh'].voice_pt
    generator = zh_pipeline(text="语音合成测试TTS。", voice=voice, speed=speed_callable)
    for _ in generator:
        pass


xml_namespace = "{http://www.w3.org/XML/1998/namespace}"

# Characters that do not count as speakable content. The second literal adds
# the full-width variants of the ASCII punctuation in the first.
symbols = (
    ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'
    ',!?;:()'
)


def contains_words(text):
    """Return True if ``text`` has a character that is neither punctuation
    (see ``symbols``) nor whitespace."""
    return any(char not in symbols and not char.isspace() for char in text)


def cut_sentences(text) -> list[str]:
    """Split ``text`` into sentences, keeping each terminator attached.

    Splits after ASCII and full-width sentence punctuation; empty pieces are
    dropped and every returned sentence is stripped.
    """
    # The capturing group makes re.split return [body, sep, body, sep, ..., body].
    splits = re.split(r"([.;?!、。?!;])", text.strip())
    bodies = splits[0::2]
    # There is always one more body than separators; pad so zip keeps the tail.
    separators = splits[1::2] + [""]
    sentences = []
    for body, separator in zip(bodies, separators):
        sentence = (body + separator).strip()
        if sentence:
            sentences.append(sentence)
    return sentences


LANGUAGE_ALIASES = {
    'z': 'zh',
    'a': 'en',
    'e': 'es',
    'f': 'fr',
    'h': 'hi',
    'i': 'it',
    'j': 'ja',
    'p': 'pt',
}


@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...), include_header: bool = False):
    """Synthesize the text of the SSML ``<voice>`` element and stream PCM audio.

    Args:
        ssml: SSML document; the first ``<voice>`` element supplies the text
            and, via ``xml:lang`` (default ``"zh"``), the language.
        include_header: When true, the audio is preceded by a WAV header.

    Returns:
        A streaming ``audio/wav`` response, a short silent response when the
        text has no speakable content, or a 400 JSON response for malformed
        SSML, a missing ``<voice>`` element, or an unsupported language.
    """
    # NOTE(review): ``ssml`` is untrusted input parsed with the stdlib XML
    # parser; confirm the deployed Python/expat version is hardened against
    # entity-expansion attacks.
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    voice_element = root.find(".//voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # itertext() also covers an empty <voice/> (``.text`` is None) and text
    # nested inside child elements.
    transcription = "".join(voice_element.itertext()).strip()
    requested_language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
    # voice_name = voice_element.get("name", "zh-f-soft-1").strip()

    if not contains_words(transcription):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        if include_header:
            audio_header = wave_header_chunk(sample_rate=TARGET_SR)
            audio = audio_header + audio
        return Response(audio, media_type='audio/wav')

    language = requested_language
    if language not in pipeline_dict:
        language = LANGUAGE_ALIASES.get(language, language)
        # An alias may point at a language whose pipeline is not loaded.
        if language not in pipeline_dict:
            return JSONResponse(
                status_code=400,
                content={"message": f"Language '{requested_language}' not supported."},
            )

    def streaming_generator():
        # Send the header first so the stream is a valid WAV even when no
        # audio chunk follows.
        if include_header:
            yield wave_header_chunk(sample_rate=TARGET_SR)
        selected = pipeline_dict[language]
        for text in cut_sentences(transcription):
            if not contains_words(text):
                continue
            if language == 'zh':
                generator = selected.pipeline(text=text, voice=selected.voice_pt, speed=speed_callable)
            else:
                generator = selected.pipeline(text=text, voice=selected.voice_pt)
            for (_, _, audio) in generator:
                if audio is None:
                    continue
                samples = audio.detach().cpu().numpy()
                yield audio_postprocess(samples, MODEL_SR, TARGET_SR).tobytes()

    return StreamingResponse(streaming_generator(), media_type='audio/wav')


@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe that always reports ok."""
    return JSONResponse(status_code=200, content={"status": "ok"})


@app.get("/health_check")
async def health_check():
    """Run a small matmul on the inference device; respond 503 on failure."""
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device=DEVICE)
        b = torch.ones(20, 10, dtype=torch.float32, device=DEVICE)
        total = torch.matmul(a, b).sum().item()
    except Exception as e:
        print(f'health_check failed: {e}')
        raise HTTPException(status_code=503) from e
    # Every entry of the 10x10 product equals 20, so the sum is 10 * 20 * 10.
    if total != 10 * 20 * 10:
        print(f'health_check failed: unexpected checksum {total}')
        raise HTTPException(status_code=503)
    return {"status": "ok"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=80)