init kokoro
This commit is contained in:
@@ -6,4 +6,7 @@ RUN pip install -r requirements_kokoro.txt -c constraints_kokoro.txt \
|
|||||||
&& apt update \
|
&& apt update \
|
||||||
&& apt install -y espeak-ng
|
&& apt install -y espeak-ng
|
||||||
|
|
||||||
|
COPY ./en_core_web_sm-3.8.0.tar.gz .
|
||||||
|
RUN pip install --no-index en_core_web_sm-3.8.0.tar.gz
|
||||||
|
|
||||||
ENTRYPOINT ["/bin/bash", "launch_kokoro.sh"]
|
ENTRYPOINT ["/bin/bash", "launch_kokoro.sh"]
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
torch==2.1.0+corex.3.2.1
|
torch==2.1.0+corex.3.2.1
|
||||||
|
numpy==1.23.5
|
||||||
|
scipy==1.14.1
|
||||||
|
|||||||
BIN
bi_v100-kokoro/en_core_web_sm-3.8.0.tar.gz
Normal file
BIN
bi_v100-kokoro/en_core_web_sm-3.8.0.tar.gz
Normal file
Binary file not shown.
@@ -1,19 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
|
import io
|
||||||
|
|
||||||
from fastapi import FastAPI, Body
|
from fastapi import FastAPI, Response, Body, HTTPException
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
from kokoro import KPipeline, KModel
|
from kokoro import KPipeline, KModel
|
||||||
|
# import soundfile as sf
|
||||||
|
import wave
|
||||||
import numpy as np
|
import numpy as np
|
||||||
# from scipy.signal import resample
|
from scipy.signal import resample
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||||
if self.padding_mode != 'zeros':
|
if self.padding_mode != 'zeros':
|
||||||
@@ -35,16 +41,27 @@ torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
|
|||||||
|
|
||||||
|
|
||||||
repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'
|
repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'
|
||||||
# MODEL_SR = 24000
|
MODEL_SR = 24000
|
||||||
|
TARGET_SR = 16000
|
||||||
|
# How much silence to insert between paragraphs: 5000 is about 0.2 seconds
|
||||||
|
N_ZEROS = 20
|
||||||
model = None
|
model = None
|
||||||
en_empty_pipeline = None
|
en_empty_pipeline = None
|
||||||
en_pipeline = None
|
|
||||||
zh_pipeline = None
|
|
||||||
en_voice_pt = None
|
|
||||||
zh_voice_pt = None
|
|
||||||
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
|
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
|
||||||
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
|
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
|
||||||
model_dir = os.getenv('MODEL_DIR', '/models/hexgrad/Kokoro-82M-v1.1-zh')
|
model_dir = os.getenv('MODEL_DIR', '/model/hexgrad')
|
||||||
|
model_name = os.getenv('MODEL_NAME','kokoro-v1_1-zh.pth')
|
||||||
|
|
||||||
|
# model_1_1_dir = os.path.join(model_dir, 'Kokoro-82M-v1.1-zh')
|
||||||
|
# model_1_0_dir = os.path.join(model_dir, 'Kokoro-82M')
|
||||||
|
# repo_id_1_0 = 'hexgrad/Kokoro-82M'
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LanguagePipeline:
|
||||||
|
pipeline: KPipeline
|
||||||
|
voice_pt: str
|
||||||
|
|
||||||
|
pipeline_dict: dict[str, LanguagePipeline] = {}
|
||||||
|
|
||||||
def en_callable(text):
|
def en_callable(text):
|
||||||
if text == 'Kokoro':
|
if text == 'Kokoro':
|
||||||
@@ -63,29 +80,80 @@ def speed_callable(len_ps):
|
|||||||
speed = 1 - (len_ps - 83) / 500
|
speed = 1 - (len_ps - 83) / 500
|
||||||
return speed
|
return speed
|
||||||
|
|
||||||
# def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
|
|
||||||
# ori_dtype = data.dtype
|
|
||||||
# # data = normalize_audio(data)
|
|
||||||
# number_of_samples = int(len(data) * float(target_rate) / original_rate)
|
|
||||||
# resampled_data = resample(data, number_of_samples)
|
|
||||||
# # resampled_data = normalize_audio(resampled_data)
|
|
||||||
# return resampled_data.astype(ori_dtype)
|
|
||||||
|
|
||||||
def audio_postprocess(audio: np.ndarray):
|
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
|
||||||
|
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
|
||||||
|
# This will create a wave header then append the frame input
|
||||||
|
# It should be first on a streaming wav file
|
||||||
|
# Other frames better should not have it (else you will hear some artifacts each chunk start)
|
||||||
|
wav_buf = io.BytesIO()
|
||||||
|
with wave.open(wav_buf, "wb") as vfout:
|
||||||
|
vfout.setnchannels(channels)
|
||||||
|
vfout.setsampwidth(sample_width)
|
||||||
|
vfout.setframerate(sample_rate)
|
||||||
|
vfout.writeframes(frame_input)
|
||||||
|
|
||||||
|
wav_buf.seek(0)
|
||||||
|
return wav_buf.read()
|
||||||
|
|
||||||
|
|
||||||
|
def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
|
||||||
|
ori_dtype = data.dtype
|
||||||
|
# data = normalize_audio(data)
|
||||||
|
number_of_samples = int(len(data) * float(target_rate) / original_rate)
|
||||||
|
resampled_data = resample(data, number_of_samples)
|
||||||
|
# resampled_data = normalize_audio(resampled_data)
|
||||||
|
return resampled_data.astype(ori_dtype)
|
||||||
|
|
||||||
|
def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
|
||||||
|
audio = resample_audio(data, original_rate, target_rate)
|
||||||
if audio.dtype == np.float32:
|
if audio.dtype == np.float32:
|
||||||
audio = np.int16(audio * 32767)
|
audio = np.int16(audio * 32767)
|
||||||
|
audio = np.concatenate([audio, np.zeros(N_ZEROS, dtype=np.int16)])
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
global model, en_empty_pipeline, en_pipeline, zh_pipeline
|
global model, en_empty_pipeline
|
||||||
global en_voice_pt, zh_voice_pt
|
global model_1_0
|
||||||
|
global pipeline_dict
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
model = KModel(repo_id=repo_id, model=os.path.join(model_dir, 'kokoro-v1_1-zh.pth'), config=os.path.join(model_dir, 'config.json')).to(device).eval()
|
model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval()
|
||||||
en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
|
en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
|
||||||
en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
|
en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
|
||||||
zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
|
zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
|
||||||
en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
|
en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
|
||||||
zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)
|
zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)
|
||||||
|
pipeline_dict['zh'] = LanguagePipeline(pipeline=zh_pipeline, voice_pt=zh_voice_pt)
|
||||||
|
pipeline_dict['en'] = LanguagePipeline(pipeline=en_pipeline, voice_pt=en_voice_pt)
|
||||||
|
|
||||||
|
# v1.0 model for other languages
|
||||||
|
# model_1_0 = KModel(repo_id=repo_id_1_0, model=os.path.join(model_1_0_dir, 'kokoro-v1_0.pth'), config=os.path.join(model_1_0_dir, 'config.json')).to(device).eval()
|
||||||
|
# # es
|
||||||
|
# es_pipeline = KPipeline(lang_code='e', repo_id=repo_id_1_0, model=model_1_0)
|
||||||
|
# es_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ef_dora.pt')
|
||||||
|
# pipeline_dict['es'] = LanguagePipeline(pipeline=es_pipeline, voice_pt=es_voice_pt)
|
||||||
|
# # fr
|
||||||
|
# fr_pipeline = KPipeline(lang_code='f', repo_id=repo_id_1_0, model=model_1_0)
|
||||||
|
# fr_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ff_siwis.pt')
|
||||||
|
# pipeline_dict['fr'] = LanguagePipeline(pipeline=fr_pipeline, voice_pt=fr_voice_pt)
|
||||||
|
# # hi
|
||||||
|
# hi_pipeline = KPipeline(lang_code='h', repo_id=repo_id_1_0, model=model_1_0)
|
||||||
|
# hi_voice_pt = os.path.join(model_1_0_dir, 'voices', 'hf_alpha.pt')
|
||||||
|
# pipeline_dict['hi'] = LanguagePipeline(pipeline=hi_pipeline, voice_pt=hi_voice_pt)
|
||||||
|
# # it
|
||||||
|
# it_pipeline = KPipeline(lang_code='i', repo_id=repo_id_1_0, model=model_1_0)
|
||||||
|
# it_voice_pt = os.path.join(model_1_0_dir, 'voices', 'if_sara.pt')
|
||||||
|
# pipeline_dict['it'] = LanguagePipeline(pipeline=it_pipeline, voice_pt=it_voice_pt)
|
||||||
|
# # ja
|
||||||
|
# ja_pipeline = KPipeline(lang_code='j', repo_id=repo_id_1_0, model=model_1_0)
|
||||||
|
# ja_voice_pt = os.path.join(model_1_0_dir, 'voices', 'jf_alpha.pt')
|
||||||
|
# pipeline_dict['ja'] = LanguagePipeline(pipeline=ja_pipeline, voice_pt=ja_voice_pt)
|
||||||
|
# # pt
|
||||||
|
# pt_pipeline = KPipeline(lang_code='p', repo_id=repo_id_1_0, model=model_1_0)
|
||||||
|
# pt_voice_pt = os.path.join(model_1_0_dir, 'voices', 'pf_dora.pt')
|
||||||
|
# pipeline_dict['pt'] = LanguagePipeline(pipeline=pt_pipeline, voice_pt=pt_voice_pt)
|
||||||
|
|
||||||
|
warmup()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -95,29 +163,90 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
def warmup():
|
||||||
|
zh_pipeline = pipeline_dict['zh'].pipeline
|
||||||
|
voice = pipeline_dict['zh'].voice_pt
|
||||||
|
generator = zh_pipeline(text="语音合成测试TTS。", voice=voice, speed=speed_callable)
|
||||||
|
for _ in generator:
|
||||||
|
pass
|
||||||
|
|
||||||
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
|
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
|
||||||
|
|
||||||
# return 24kHz pcm-16
|
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'
|
||||||
|
def contains_words(text):
|
||||||
|
return any(char not in symbols for char in text)
|
||||||
|
|
||||||
|
def cut_sentences(text) -> list[str]:
|
||||||
|
text = text.strip()
|
||||||
|
splits = re.split(r"([.;?!、。?!;])", text)
|
||||||
|
sentences = []
|
||||||
|
for i in range(0, len(splits), 2):
|
||||||
|
if i + 1 < len(splits):
|
||||||
|
s = splits[i] + splits[i + 1]
|
||||||
|
else:
|
||||||
|
s = splits[i]
|
||||||
|
s = s.strip()
|
||||||
|
if s:
|
||||||
|
sentences.append(s)
|
||||||
|
return sentences
|
||||||
|
|
||||||
|
LANGUAGE_ALIASES = {
|
||||||
|
'z': 'zh',
|
||||||
|
'a': 'en',
|
||||||
|
'e': 'es',
|
||||||
|
'f': 'fr',
|
||||||
|
'h': 'hi',
|
||||||
|
'i': 'it',
|
||||||
|
'j': 'ja',
|
||||||
|
'p': 'pt',
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/")
|
||||||
@app.post("/tts")
|
@app.post("/tts")
|
||||||
def generate(ssml: str = Body(...)):
|
def predict(ssml: str = Body(...), include_header: bool = False):
|
||||||
try:
|
try:
|
||||||
root = ET.fromstring(ssml)
|
root = ET.fromstring(ssml)
|
||||||
voice_element = root.find(".//voice")
|
voice_element = root.find(".//voice")
|
||||||
if voice_element is not None:
|
if voice_element is not None:
|
||||||
text = voice_element.text.strip()
|
transcription = voice_element.text.strip()
|
||||||
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
|
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
|
||||||
|
# voice_name = voice_element.get("name", "zh-f-soft-1").strip()
|
||||||
else:
|
else:
|
||||||
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
|
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
|
||||||
except ET.ParseError as e:
|
except ET.ParseError as e:
|
||||||
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
|
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
|
||||||
|
|
||||||
|
if not contains_words(transcription):
|
||||||
|
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
|
||||||
|
if include_header:
|
||||||
|
audio_header = wave_header_chunk(sample_rate=TARGET_SR)
|
||||||
|
audio = audio_header + audio
|
||||||
|
return Response(audio, media_type='audio/wav')
|
||||||
|
|
||||||
|
if language not in pipeline_dict:
|
||||||
|
if language in LANGUAGE_ALIASES:
|
||||||
|
language = LANGUAGE_ALIASES[language]
|
||||||
|
else:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"Language '{language}' not supported."})
|
||||||
|
|
||||||
|
|
||||||
def streaming_generator():
|
def streaming_generator():
|
||||||
if language == 'en':
|
texts = cut_sentences(transcription)
|
||||||
generator = en_pipeline(text=text, voice=en_voice_pt)
|
has_yield = False
|
||||||
else:
|
for text in texts:
|
||||||
generator = zh_pipeline(text=text, voice=zh_voice_pt, speed=speed_callable)
|
if text.strip() and contains_words(text):
|
||||||
for (_, _, audio) in generator:
|
pipeline = pipeline_dict[language].pipeline
|
||||||
yield audio_postprocess(audio.numpy()).tobytes()
|
voice = pipeline_dict[language].voice_pt
|
||||||
|
if language == 'zh':
|
||||||
|
generator = pipeline(text=text, voice=voice, speed=speed_callable)
|
||||||
|
else:
|
||||||
|
generator = pipeline(text=text, voice=voice)
|
||||||
|
|
||||||
|
for (_, _, audio) in generator:
|
||||||
|
if include_header and not has_yield:
|
||||||
|
has_yield = True
|
||||||
|
yield wave_header_chunk(sample_rate=TARGET_SR)
|
||||||
|
yield audio_postprocess(audio.numpy(), MODEL_SR, TARGET_SR).tobytes()
|
||||||
|
|
||||||
return StreamingResponse(streaming_generator(), media_type='audio/wav')
|
return StreamingResponse(streaming_generator(), media_type='audio/wav')
|
||||||
|
|
||||||
@@ -127,6 +256,19 @@ def generate(ssml: str = Body(...)):
|
|||||||
async def ready():
|
async def ready():
|
||||||
return JSONResponse(status_code=200, content={"status": "ok"})
|
return JSONResponse(status_code=200, content={"status": "ok"})
|
||||||
|
|
||||||
|
@app.get("/health_check")
|
||||||
|
async def health_check():
|
||||||
|
try:
|
||||||
|
a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
|
||||||
|
b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
|
||||||
|
c = torch.matmul(a, b)
|
||||||
|
if c.sum() == 10 * 20 * 10:
|
||||||
|
return {"status": "ok"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=503)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'health_check failed')
|
||||||
|
raise HTTPException(status_code=503)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(app, host="0.0.0.0", port=80)
|
uvicorn.run(app, host="0.0.0.0", port=80)
|
||||||
|
|||||||
@@ -2,4 +2,4 @@ kokoro>=0.8.2
|
|||||||
misaki[zh]>=0.8.2
|
misaki[zh]>=0.8.2
|
||||||
soundfile
|
soundfile
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
|
|||||||
Reference in New Issue
Block a user