diff --git a/bi_v100-kokoro/Dockerfile_kokoro b/bi_v100-kokoro/Dockerfile_kokoro index a120899..35b85a6 100644 --- a/bi_v100-kokoro/Dockerfile_kokoro +++ b/bi_v100-kokoro/Dockerfile_kokoro @@ -6,4 +6,7 @@ RUN pip install -r requirements_kokoro.txt -c constraints_kokoro.txt \ && apt update \ && apt install -y espeak-ng +COPY ./en_core_web_sm-3.8.0.tar.gz . +RUN pip install --no-index en_core_web_sm-3.8.0.tar.gz + ENTRYPOINT ["/bin/bash", "launch_kokoro.sh"] diff --git a/bi_v100-kokoro/constraints_kokoro.txt b/bi_v100-kokoro/constraints_kokoro.txt index 65d66ac..03d1dda 100644 --- a/bi_v100-kokoro/constraints_kokoro.txt +++ b/bi_v100-kokoro/constraints_kokoro.txt @@ -1 +1,3 @@ -torch==2.1.0+corex.3.2.1 \ No newline at end of file +torch==2.1.0+corex.3.2.1 +numpy==1.23.5 +scipy==1.14.1 diff --git a/bi_v100-kokoro/en_core_web_sm-3.8.0.tar.gz b/bi_v100-kokoro/en_core_web_sm-3.8.0.tar.gz new file mode 100644 index 0000000..1069d5b Binary files /dev/null and b/bi_v100-kokoro/en_core_web_sm-3.8.0.tar.gz differ diff --git a/bi_v100-kokoro/kokoro_server.py b/bi_v100-kokoro/kokoro_server.py index 8602998..633f535 100644 --- a/bi_v100-kokoro/kokoro_server.py +++ b/bi_v100-kokoro/kokoro_server.py @@ -1,19 +1,25 @@ import os +import io -from fastapi import FastAPI, Body +from fastapi import FastAPI, Response, Body, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from contextlib import asynccontextmanager import uvicorn import xml.etree.ElementTree as ET from kokoro import KPipeline, KModel +# import soundfile as sf +import wave import numpy as np -# from scipy.signal import resample +from scipy.signal import resample import torch from torch import Tensor from torch.nn import functional as F from typing import Optional, List +import re +from dataclasses import dataclass + def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: if self.padding_mode != 'zeros': @@ -35,16 +41,27 @@ 
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward repo_id = 'hexgrad/Kokoro-82M-v1.1-zh' -# MODEL_SR = 24000 +MODEL_SR = 24000 +TARGET_SR = 16000 +# Silence (samples at TARGET_SR) appended after each audio chunk: 20 samples is ~1.25 ms at 16 kHz +N_ZEROS = 20 model = None en_empty_pipeline = None -en_pipeline = None -zh_pipeline = None -en_voice_pt = None -zh_voice_pt = None en_voice = os.getenv('EN_VOICE', 'af_maple.pt') zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt') -model_dir = os.getenv('MODEL_DIR', '/models/hexgrad/Kokoro-82M-v1.1-zh') +model_dir = os.getenv('MODEL_DIR', '/model/hexgrad') +model_name = os.getenv('MODEL_NAME','kokoro-v1_1-zh.pth') + +# model_1_1_dir = os.path.join(model_dir, 'Kokoro-82M-v1.1-zh') +# model_1_0_dir = os.path.join(model_dir, 'Kokoro-82M') +# repo_id_1_0 = 'hexgrad/Kokoro-82M' + +@dataclass +class LanguagePipeline: + pipeline: KPipeline + voice_pt: str + +pipeline_dict: dict[str, LanguagePipeline] = {} def en_callable(text): if text == 'Kokoro': @@ -63,29 +80,80 @@ def speed_callable(len_ps): speed = 1 - (len_ps - 83) / 500 return speed -# def resample_audio(data: np.ndarray, original_rate: int, target_rate: int): -# ori_dtype = data.dtype -# # data = normalize_audio(data) -# number_of_samples = int(len(data) * float(target_rate) / original_rate) -# resampled_data = resample(data, number_of_samples) -# # resampled_data = normalize_audio(resampled_data) -# return resampled_data.astype(ori_dtype) -def audio_postprocess(audio: np.ndarray): +# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + # This will create a wave header then append the frame input + # It should be first on a streaming wav file + # Other frames better should not have it (else you will hear some artifacts each chunk start) + wav_buf = io.BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + 
vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + return wav_buf.read() + + +def resample_audio(data: np.ndarray, original_rate: int, target_rate: int): + ori_dtype = data.dtype + # data = normalize_audio(data) + number_of_samples = int(len(data) * float(target_rate) / original_rate) + resampled_data = resample(data, number_of_samples) + # resampled_data = normalize_audio(resampled_data) + return resampled_data.astype(ori_dtype) + +def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int): + audio = resample_audio(data, original_rate, target_rate) if audio.dtype == np.float32: audio = np.int16(audio * 32767) + audio = np.concatenate([audio, np.zeros(N_ZEROS, dtype=np.int16)]) return audio def init(): - global model, en_empty_pipeline, en_pipeline, zh_pipeline - global en_voice_pt, zh_voice_pt + global model, en_empty_pipeline + global model_1_0 + global pipeline_dict device = 'cuda' if torch.cuda.is_available() else 'cpu' - model = KModel(repo_id=repo_id, model=os.path.join(model_dir, 'kokoro-v1_1-zh.pth'), config=os.path.join(model_dir, 'config.json')).to(device).eval() + model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval() en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False) en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model) zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable) en_voice_pt = os.path.join(model_dir, 'voices', en_voice) zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice) + pipeline_dict['zh'] = LanguagePipeline(pipeline=zh_pipeline, voice_pt=zh_voice_pt) + pipeline_dict['en'] = LanguagePipeline(pipeline=en_pipeline, voice_pt=en_voice_pt) + + # v1.0 model for other languages + # model_1_0 = KModel(repo_id=repo_id_1_0, model=os.path.join(model_1_0_dir, 'kokoro-v1_0.pth'), config=os.path.join(model_1_0_dir, 
'config.json')).to(device).eval() + # # es + # es_pipeline = KPipeline(lang_code='e', repo_id=repo_id_1_0, model=model_1_0) + # es_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ef_dora.pt') + # pipeline_dict['es'] = LanguagePipeline(pipeline=es_pipeline, voice_pt=es_voice_pt) + # # fr + # fr_pipeline = KPipeline(lang_code='f', repo_id=repo_id_1_0, model=model_1_0) + # fr_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ff_siwis.pt') + # pipeline_dict['fr'] = LanguagePipeline(pipeline=fr_pipeline, voice_pt=fr_voice_pt) + # # hi + # hi_pipeline = KPipeline(lang_code='h', repo_id=repo_id_1_0, model=model_1_0) + # hi_voice_pt = os.path.join(model_1_0_dir, 'voices', 'hf_alpha.pt') + # pipeline_dict['hi'] = LanguagePipeline(pipeline=hi_pipeline, voice_pt=hi_voice_pt) + # # it + # it_pipeline = KPipeline(lang_code='i', repo_id=repo_id_1_0, model=model_1_0) + # it_voice_pt = os.path.join(model_1_0_dir, 'voices', 'if_sara.pt') + # pipeline_dict['it'] = LanguagePipeline(pipeline=it_pipeline, voice_pt=it_voice_pt) + # # ja + # ja_pipeline = KPipeline(lang_code='j', repo_id=repo_id_1_0, model=model_1_0) + # ja_voice_pt = os.path.join(model_1_0_dir, 'voices', 'jf_alpha.pt') + # pipeline_dict['ja'] = LanguagePipeline(pipeline=ja_pipeline, voice_pt=ja_voice_pt) + # # pt + # pt_pipeline = KPipeline(lang_code='p', repo_id=repo_id_1_0, model=model_1_0) + # pt_voice_pt = os.path.join(model_1_0_dir, 'voices', 'pf_dora.pt') + # pipeline_dict['pt'] = LanguagePipeline(pipeline=pt_pipeline, voice_pt=pt_voice_pt) + + warmup() @asynccontextmanager async def lifespan(app: FastAPI): @@ -95,29 +163,90 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) +def warmup(): + zh_pipeline = pipeline_dict['zh'].pipeline + voice = pipeline_dict['zh'].voice_pt + generator = zh_pipeline(text="语音合成测试TTS。", voice=voice, speed=speed_callable) + for _ in generator: + pass + xml_namespace = "{http://www.w3.org/XML/1998/namespace}" -# return 24kHz pcm-16 +symbols = 
',.!?;:()[]{}<>,。!?;:【】《》……"“”_—' +def contains_words(text): + return any(char not in symbols for char in text) + +def cut_sentences(text) -> list[str]: + text = text.strip() + splits = re.split(r"([.;?!、。?!;])", text) + sentences = [] + for i in range(0, len(splits), 2): + if i + 1 < len(splits): + s = splits[i] + splits[i + 1] + else: + s = splits[i] + s = s.strip() + if s: + sentences.append(s) + return sentences + +LANGUAGE_ALIASES = { + 'z': 'zh', + 'a': 'en', + 'e': 'es', + 'f': 'fr', + 'h': 'hi', + 'i': 'it', + 'j': 'ja', + 'p': 'pt', +} + +@app.post("/") @app.post("/tts") -def generate(ssml: str = Body(...)): +def predict(ssml: str = Body(...), include_header: bool = False): try: root = ET.fromstring(ssml) voice_element = root.find(".//voice") if voice_element is not None: - text = voice_element.text.strip() + transcription = voice_element.text.strip() language = voice_element.get(f'{xml_namespace}lang', "zh").strip() + # voice_name = voice_element.get("name", "zh-f-soft-1").strip() else: return JSONResponse(status_code=400, content={"message": "Invalid SSML format: element not found."}) except ET.ParseError as e: return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)}) + + if not contains_words(transcription): + audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes() + if include_header: + audio_header = wave_header_chunk(sample_rate=TARGET_SR) + audio = audio_header + audio + return Response(audio, media_type='audio/wav') + + if language not in pipeline_dict: + if language in LANGUAGE_ALIASES: + language = LANGUAGE_ALIASES[language] + else: + return JSONResponse(status_code=400, content={"message": f"Language '{language}' not supported."}) + def streaming_generator(): - if language == 'en': - generator = en_pipeline(text=text, voice=en_voice_pt) - else: - generator = zh_pipeline(text=text, voice=zh_voice_pt, speed=speed_callable) - for (_, _, audio) in generator: - yield audio_postprocess(audio.numpy()).tobytes() + 
texts = cut_sentences(transcription) + has_yield = False + for text in texts: + if text.strip() and contains_words(text): + pipeline = pipeline_dict[language].pipeline + voice = pipeline_dict[language].voice_pt + if language == 'zh': + generator = pipeline(text=text, voice=voice, speed=speed_callable) + else: + generator = pipeline(text=text, voice=voice) + + for (_, _, audio) in generator: + if include_header and not has_yield: + has_yield = True + yield wave_header_chunk(sample_rate=TARGET_SR) + yield audio_postprocess(audio.numpy(), MODEL_SR, TARGET_SR).tobytes() return StreamingResponse(streaming_generator(), media_type='audio/wav') @@ -127,6 +256,19 @@ def generate(ssml: str = Body(...)): async def ready(): return JSONResponse(status_code=200, content={"status": "ok"}) +@app.get("/health_check") +async def health_check(): + try: + a = torch.ones(10, 20, dtype=torch.float32, device='cuda') + b = torch.ones(20, 10, dtype=torch.float32, device='cuda') + c = torch.matmul(a, b) + if c.sum() == 10 * 20 * 10: + return {"status": "ok"} + else: + raise HTTPException(status_code=503) + except Exception as e: + print(f'health_check failed') + raise HTTPException(status_code=503) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=80) diff --git a/bi_v100-kokoro/requirements_kokoro.txt b/bi_v100-kokoro/requirements_kokoro.txt index 69bc8c2..517d64b 100644 --- a/bi_v100-kokoro/requirements_kokoro.txt +++ b/bi_v100-kokoro/requirements_kokoro.txt @@ -2,4 +2,4 @@ kokoro>=0.8.2 misaki[zh]>=0.8.2 soundfile fastapi -uvicorn[standard] \ No newline at end of file +uvicorn[standard]