init muxi
This commit is contained in:
21
metaX-C500-kokoro/Dockerfile_kokoro
Normal file
21
metaX-C500-kokoro/Dockerfile_kokoro
Normal file
@@ -0,0 +1,21 @@
|
||||
# Kokoro TTS server image built on the MetaX C500 MACA PyTorch base image.
FROM git.modelhub.org.cn:9443/enginex-metax/maca-c500-pytorch:2.33.0.6-torch2.6-py310-ubuntu24.04-amd64

WORKDIR /workspace

ENV CONDA_DIR=/opt/conda
ENV PATH=${CONDA_DIR}/bin:${PATH}

USER root
# Install the espeak-ng system package.
# NOTE(review): /tmp is re-chmodded and apt is given its own temp directory,
# presumably because the base image's /tmp permissions break apt - confirm.
RUN set -eux; \
    chmod 1777 /tmp; \
    mkdir -p /var/tmp/apt-tmp && chmod 1777 /var/tmp/apt-tmp; \
    apt-get -o Dir::Temp::=/var/tmp/apt-tmp update && \
    DEBIAN_FRONTEND=noninteractive apt-get -o Dir::Temp::=/var/tmp/apt-tmp install -y --no-install-recommends espeak-ng && \
    rm -rf /var/lib/apt/lists/* /var/tmp/apt-tmp

# Dependency inputs are copied and installed before the application code so
# that editing kokoro_server.py does not invalidate the (slow) pip layers.
# --no-cache-dir keeps pip's download cache out of the image.
COPY requirements_kokoro.txt constraints_kokoro.txt en_core_web_sm-3.8.0.tar.gz /workspace/
RUN pip install --no-cache-dir -r requirements_kokoro.txt -c constraints_kokoro.txt
RUN pip install --no-cache-dir en_core_web_sm-3.8.0.tar.gz

# Application code last: these layers change most often.
COPY kokoro_server.py launch_kokoro.sh /workspace/
ENTRYPOINT ["/bin/bash", "launch_kokoro.sh"]
|
||||
46
metaX-C500-kokoro/README.md
Normal file
46
metaX-C500-kokoro/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Kokoro-TTS
|
||||
|
||||
本项目基于 **Kokoro** 模型封装,提供简洁的 Docker 部署方式,支持 **SSML 输入**,输出 **PCM 原始音频**(16 kHz、16 位、单声道),可用于语音合成。
|
||||
|
||||
---
|
||||
|
||||
## Quickstart
|
||||
|
||||
### 1. 构建镜像
|
||||
```bash
|
||||
docker build -t tts:kokoro . -f Dockerfile_kokoro
|
||||
```
|
||||
|
||||
### 2. 启动服务
|
||||
```bash
|
||||
metax-docker run -it --rm \
|
||||
-v /models/Kokoro-82M-v1.1-zh:/mnt/models \
|
||||
--gpus=[2] \
|
||||
-p 8080:80 \
|
||||
-e MODEL_DIR=/mnt/models \
|
||||
-e MODEL_NAME=kokoro-v1_1-zh.pth \
|
||||
tts:kokoro
|
||||
```
|
||||
|
||||
参数说明:
|
||||
- `MODEL_DIR`:模型所在目录(挂载到容器内 `/mnt/models`)
|
||||
- `MODEL_NAME`:加载的模型文件名(例如 `kokoro-v1_1-zh.pth`)
|
||||
- `-p 8080:80`:将容器内服务端口映射到宿主机 `8080`
|
||||
|
||||
### 3. 测试服务
|
||||
```bash
|
||||
curl --request POST "http://localhost:8080/tts" \
|
||||
--header 'Content-Type: application/ssml+xml' \
|
||||
--header 'User-Agent: curl' \
|
||||
--data-raw '<speak version="1.0" xml:lang="zh">
|
||||
<voice xml:lang="zh" xml:gender="Female" name="zh">
|
||||
今天天气很好,不知道明天天气怎么样。
|
||||
</voice>
|
||||
</speak>' \
|
||||
--output sound.pcm
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
- 无 Patch 能运行,但是生成的音频会有噪声
|
||||
- Patch: 将 decoder/istftnet 固定为 CPU FP32(禁用 AMP/TF32),修复 GPU “打字机”噪声
|
||||
1
metaX-C500-kokoro/constraints_kokoro.txt
Normal file
1
metaX-C500-kokoro/constraints_kokoro.txt
Normal file
@@ -0,0 +1 @@
|
||||
numpy==1.26.4
|
||||
BIN
metaX-C500-kokoro/en_core_web_sm-3.8.0.tar.gz
Normal file
BIN
metaX-C500-kokoro/en_core_web_sm-3.8.0.tar.gz
Normal file
Binary file not shown.
311
metaX-C500-kokoro/kokoro_server.py
Normal file
311
metaX-C500-kokoro/kokoro_server.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import os
|
||||
import io
|
||||
|
||||
from fastapi import FastAPI, Response, Body, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from kokoro import KPipeline, KModel
|
||||
# import soundfile as sf
|
||||
import wave
|
||||
import numpy as np
|
||||
from scipy.signal import resample
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional, List
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
# --- Model / audio configuration ---------------------------------------------
repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'

# Source rate assumed for model audio and the rate the service resamples to
# before emitting PCM.
MODEL_SR, TARGET_SR = 24000, 16000

# Number of zero samples appended after every emitted audio chunk as trailing
# silence. For scale: 5000 samples would be about 0.2 seconds at MODEL_SR;
# the current value is far shorter.
N_ZEROS = 20

# Populated by init().
model = None
en_empty_pipeline = None

# Deployment settings, each overridable through an environment variable.
model_dir = os.environ.get('MODEL_DIR', '/model/hexgrad')
model_name = os.environ.get('MODEL_NAME', 'kokoro-v1_1-zh.pth')
en_voice = os.environ.get('EN_VOICE', 'af_maple.pt')
zh_voice = os.environ.get('ZH_VOICE', 'zf_046.pt')
|
||||
|
||||
# model_1_1_dir = os.path.join(model_dir, 'Kokoro-82M-v1.1-zh')
|
||||
# model_1_0_dir = os.path.join(model_dir, 'Kokoro-82M')
|
||||
# repo_id_1_0 = 'hexgrad/Kokoro-82M'
|
||||
|
||||
@dataclass
class LanguagePipeline:
    """Pairing of a language's synthesis pipeline with its voice file path."""

    # KPipeline instance used to synthesise text for the language.
    pipeline: KPipeline
    # Path to the voice file passed to the pipeline as ``voice``.
    voice_pt: str


# Registry of loaded languages, keyed by language code (filled in by init()).
pipeline_dict: dict[str, LanguagePipeline] = dict()
|
||||
|
||||
def en_callable(text):
    """Return the phoneme string for ``text``.

    Two words have fixed pronunciations; anything else is run through the
    model-less English pipeline and the phonemes of its first result are
    returned.
    """
    fixed_pronunciations = (
        ('Kokoro', 'kˈOkəɹO'),
        ('Sol', 'sˈOl'),
    )
    for word, phonemes in fixed_pronunciations:
        if text == word:
            return phonemes
    first_result = next(en_empty_pipeline(text))
    return first_result.phonemes
|
||||
|
||||
# HACK: the model rushes on long inputs (little training data beyond ~100
# tokens), so speech is slowed down progressively as the phoneme count grows.
def speed_callable(len_ps):
    """Return the speed factor for a phoneme sequence of length ``len_ps``.

    Piecewise linear: 1 up to 83 phonemes, then decreasing by 1/500 per
    additional phoneme, and a flat 0.8 from 183 phonemes onwards.
    """
    if len_ps <= 83:
        return 1
    if len_ps < 183:
        return 1 - (len_ps - 83) / 500
    return 0.8
|
||||
|
||||
|
||||
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    """Build a WAV header followed by ``frame_input`` and return it as bytes.

    Intended as the first chunk of a streamed WAV file; later chunks should be
    sent without a header, otherwise artifacts are audible at chunk starts.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    # The wave writer does not close a file object it was handed, so the
    # buffer contents are still readable here.
    return buffer.getvalue()
|
||||
|
||||
|
||||
def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
    """Resample ``data`` from ``original_rate`` to ``target_rate``.

    The output length is ``len(data) * target_rate / original_rate`` truncated
    to an integer, and the result is cast back to the dtype of ``data``.
    """
    target_length = int(len(data) * float(target_rate) / original_rate)
    return resample(data, target_length).astype(data.dtype)
|
||||
|
||||
def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
    """Resample ``data`` and return it as PCM samples with trailing silence.

    Args:
        data: 1-D audio samples; floating-point input is treated as being in
            the nominal range [-1.0, 1.0].
        original_rate: Sample rate of ``data``.
        target_rate: Sample rate of the returned audio.

    Returns:
        The resampled audio followed by ``N_ZEROS`` zero samples. Floating
        point audio of any precision is converted to int16.
    """
    audio = resample_audio(data, original_rate, target_rate)
    if np.issubdtype(audio.dtype, np.floating):
        # Resampling can overshoot slightly beyond [-1, 1]; without clipping
        # the int16 conversion would wrap around and produce loud clicks.
        audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    return np.concatenate([audio, np.zeros(N_ZEROS, dtype=np.int16)])
|
||||
|
||||
|
||||
# ================== decoder/istftnet 补丁 ==================
|
||||
|
||||
def _to_cpu_fp32(obj):
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().to("cpu", dtype=torch.float32)
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(_to_cpu_fp32(x) for x in obj)
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_fp32(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
def patch_decoder(model, device: str):
    """Pin ``model.decoder`` to ``device`` in float32 and wrap its ``forward``.

    After patching, every call to the decoder runs with CUDA autocast
    disabled, with all tensor inputs detached and moved to ``device`` as
    float32, and with tensor outputs returned detached on ``device`` as
    float32.

    Args:
        model: Object exposing the decoder module as ``model.decoder``.
        device: Device the decoder runs on, e.g. ``"cpu"`` or ``"cuda"``.

    Raises:
        RuntimeError: If ``model`` has no ``decoder`` attribute.
    """
    decoder = getattr(model, "decoder", None)
    if decoder is None:
        raise RuntimeError("未找到 model.decoder,请 print(model) 确认实际模块名。")

    # .float() casts every floating-point parameter and buffer (nested ones
    # included) to float32, so no per-tensor cast loop is needed afterwards.
    decoder.eval().to(device).float()
    try:
        torch.nn.utils.remove_weight_norm(decoder)
    except Exception:
        # Best effort: remove_weight_norm raises when the module carries no
        # weight-norm hook, which is not an error for this patch.
        pass

    # Inference only: freeze all decoder parameters.
    decoder.requires_grad_(False)

    orig_forward = decoder.forward

    def _to_device_fp32(obj):
        """Recursively detach tensors and move them to ``device`` as float32."""
        if torch.is_tensor(obj):
            return obj.detach().to(device, dtype=torch.float32)
        if isinstance(obj, (list, tuple)):
            return type(obj)(_to_device_fp32(item) for item in obj)
        if isinstance(obj, dict):
            return {key: _to_device_fp32(value) for key, value in obj.items()}
        return obj

    def forward_patched(*args, **kwargs):
        # Key point: while the decoder runs, force FP32 and make sure the
        # inputs live on the same device as the decoder. Using ``device``
        # here (rather than a literal "cuda") keeps non-default devices such
        # as "cuda:1" working.
        with torch.amp.autocast('cuda', enabled=False):
            out = orig_forward(*_to_device_fp32(args), **_to_device_fp32(kwargs))
            return _to_device_fp32(out)

    decoder.forward = forward_patched
|
||||
|
||||
|
||||
|
||||
def init():
    """Load the model, register the zh/en pipelines and run a warm-up pass."""
    global model, en_empty_pipeline
    global pipeline_dict

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    weights_path = os.path.join(model_dir, model_name)
    config_path = os.path.join(model_dir, 'config.json')
    model = KModel(repo_id=repo_id, model=weights_path, config=config_path).to(device).eval()
    # The decoder is pinned to CPU float32 even when the rest of the model
    # was placed on the GPU above.
    patch_decoder(model, "cpu")

    # Model-less English pipeline; en_callable uses it to obtain phonemes.
    en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
    en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
    zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)

    voices_dir = os.path.join(model_dir, 'voices')
    pipeline_dict['zh'] = LanguagePipeline(
        pipeline=zh_pipeline,
        voice_pt=os.path.join(voices_dir, zh_voice),
    )
    pipeline_dict['en'] = LanguagePipeline(
        pipeline=en_pipeline,
        voice_pt=os.path.join(voices_dir, en_voice),
    )

    # Disabled: additional languages served by the v1.0 model
    # ('hexgrad/Kokoro-82M', weights 'kokoro-v1_0.pth'). Each would be
    # registered like the entries above with these lang_code / voice pairs:
    #   es: 'e', ef_dora.pt     fr: 'f', ff_siwis.pt    hi: 'h', hf_alpha.pt
    #   it: 'i', if_sara.pt     ja: 'j', jf_alpha.pt    pt: 'p', pf_dora.pt

    warmup()
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: run init() once before serving requests."""
    init()
    # Nothing is released on shutdown, so control simply returns after yield.
    yield


app = FastAPI(lifespan=lifespan)
|
||||
|
||||
def warmup():
    """Synthesise one short Chinese sentence and discard the audio."""
    zh = pipeline_dict['zh']
    results = zh.pipeline(text="语音合成测试TTS。", voice=zh.voice_pt, speed=speed_callable)
    # The pipeline is lazy; it has to be drained for synthesis to run.
    for _result in results:
        pass
|
||||
|
||||
# Namespace prefix ElementTree uses for ``xml:`` attributes such as xml:lang.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"

# Punctuation characters that do not count as speakable content.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if ``text`` has at least one character not in ``symbols``."""
    for ch in text:
        if ch not in symbols:
            return True
    return False
|
||||
|
||||
def cut_sentences(text) -> list[str]:
    """Split ``text`` into sentences, keeping each terminator with its sentence.

    Sentences end at any of the ASCII or full-width terminators in the
    pattern below. Surrounding whitespace is stripped and empty pieces are
    dropped.
    """
    pieces = re.split(r"([.;?!、。?!;])", text.strip())
    # With one capturing group, pieces alternate body, terminator, body, ...
    bodies = pieces[0::2]
    terminators = pieces[1::2]
    terminators.append("")  # the final body has no terminator
    joined = ((body + mark).strip() for body, mark in zip(bodies, terminators))
    return [sentence for sentence in joined if sentence]
|
||||
|
||||
# Single-letter language codes accepted as shorthand for the two-letter
# language keys used in pipeline_dict.
LANGUAGE_ALIASES = dict(
    z='zh', a='en', e='es', f='fr',
    h='hi', i='it', j='ja', p='pt',
)
|
||||
|
||||
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...), include_header: bool = False):
    """Synthesise the text of the first ``<voice>`` element of an SSML document.

    Args:
        ssml: SSML document. The text of its first ``<voice>`` descendant is
            spoken; that element's ``xml:lang`` attribute selects the
            language (default ``"zh"``).
        include_header: When true, a WAV header is emitted before the audio.

    Returns:
        A streaming response of 16-bit PCM at ``TARGET_SR``; a short block of
        silence when the text contains nothing speakable; or a 400 JSON
        response for malformed SSML, a missing ``<voice>`` element or an
        unsupported language.
    """
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    voice_element = root.find(".//voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # Element.text is None for an empty element (e.g. <voice/>); treat that
    # as empty text instead of failing with AttributeError.
    transcription = (voice_element.text or "").strip()
    requested_language = voice_element.get(f'{xml_namespace}lang', "zh").strip()

    if not contains_words(transcription):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        if include_header:
            audio = wave_header_chunk(sample_rate=TARGET_SR) + audio
        return Response(audio, media_type='audio/wav')

    # Resolve single-letter aliases first, then validate the resolved code,
    # so an alias for a language that is not loaded is rejected here rather
    # than raising KeyError in the middle of the stream.
    language = requested_language
    if language not in pipeline_dict:
        language = LANGUAGE_ALIASES.get(language, language)
    if language not in pipeline_dict:
        return JSONResponse(status_code=400, content={"message": f"Language '{requested_language}' not supported."})

    lang_pipeline = pipeline_dict[language]
    # Only the Chinese pipeline uses the length-dependent speed adjustment.
    extra_kwargs = {"speed": speed_callable} if language == 'zh' else {}

    def streaming_generator():
        header_pending = include_header
        for text in cut_sentences(transcription):
            if not (text.strip() and contains_words(text)):
                continue
            results = lang_pipeline.pipeline(text=text, voice=lang_pipeline.voice_pt, **extra_kwargs)
            for (_, _, audio) in results:
                if header_pending:
                    # The header goes out once, right before the first audio.
                    header_pending = False
                    yield wave_header_chunk(sample_rate=TARGET_SR)
                samples = audio.detach().cpu().numpy()
                yield audio_postprocess(samples, MODEL_SR, TARGET_SR).tobytes()

    return StreamingResponse(streaming_generator(), media_type='audio/wav')
|
||||
|
||||
|
||||
@app.get("/health")
@app.get("/ready")
async def ready():
    """Liveness/readiness probe: always answers 200 with a fixed body."""
    body = {"status": "ok"}
    return JSONResponse(content=body, status_code=200)
|
||||
|
||||
@app.get("/health_check")
async def health_check():
    """Check the accelerator by running a small matmul on the 'cuda' device.

    Returns:
        ``{"status": "ok"}`` when the product of two all-ones matrices sums
        to the expected value.

    Raises:
        HTTPException: 503 when the computation raises or yields a wrong
            result.
    """
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        # Each of the 10x10 output entries is 20, so the sum is 10 * 20 * 10.
        healthy = bool(torch.matmul(a, b).sum() == 10 * 20 * 10)
    except Exception as e:
        # Log the cause so a failing probe can actually be diagnosed.
        print(f'health_check failed: {e!r}')
        raise HTTPException(status_code=503) from e

    if not healthy:
        print('health_check failed: unexpected matmul result')
        raise HTTPException(status_code=503)
    return {"status": "ok"}
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 80.
    uvicorn.run(app, port=80, host="0.0.0.0")
|
||||
4
metaX-C500-kokoro/launch_kokoro.sh
Executable file
4
metaX-C500-kokoro/launch_kokoro.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
# Container entry script: start the Kokoro TTS server.
# exec replaces this shell with the Python process so the server receives
# signals (e.g. SIGTERM on container stop) directly instead of via bash.

exec python3 kokoro_server.py
|
||||
|
||||
6
metaX-C500-kokoro/requirements_kokoro.txt
Normal file
6
metaX-C500-kokoro/requirements_kokoro.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
kokoro>=0.8.2
misaki[zh]>=0.8.2
soundfile
# kokoro_server.py imports scipy.signal.resample directly, so declare it
# instead of relying on the base image or a transitive dependency.
scipy
fastapi
uvicorn[standard]
setuptools>=40.8.0
|
||||
Reference in New Issue
Block a user