init ascend tts

This commit is contained in:
2025-09-05 11:27:43 +08:00
parent d53ac91bb6
commit b92a65b0fa
602 changed files with 590901 additions and 1 deletions

View File

@@ -0,0 +1,20 @@
FROM quay.io/ascend/vllm-ascend:v0.10.0rc1
WORKDIR /workspace
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
echo "deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ jammy main restricted universe multiverse" > /etc/apt/sources.list && \
echo "deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ jammy-updates main restricted universe multiverse" >> /etc/apt/sources.list && \
echo "deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ jammy-backports main restricted universe multiverse" >> /etc/apt/sources.list && \
echo "deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ jammy-security main restricted universe multiverse" >> /etc/apt/sources.list && \
apt-get update && \
apt-get install -y espeak-ng && \
rm -rf /var/lib/apt/lists/*
COPY requirements_kokoro.txt constraints_kokoro.txt kokoro_server.py en_core_web_sm-3.8.0.tar.gz /workspace/
RUN pip install -r requirements_kokoro.txt -c constraints_kokoro.txt
RUN pip install --no-index en_core_web_sm-3.8.0.tar.gz
COPY launch_kokoro.sh /workspace/
ENTRYPOINT ["/bin/bash", "launch_kokoro.sh"]

View File

@@ -0,0 +1,53 @@
# Kokoro-TTS
本项目基于 **Kokoro** 模型封装,提供简洁的 Docker 部署方式,支持 **SSML 输入**,输出 **PCM 原始音频**,可用于语音合成。
---
## Quickstart
### 1. 安装镜像
```bash
docker build -t tts:kokoro . -f Dockerfile_kokoro
```
### 2. 启动服务
```bash
docker run -it --rm \
-v /models/Kokoro-82M-v1.1-zh:/mnt/models \
-e ASCEND_VISIBLE_DEVICES=1 \
--device /dev/davinci2:/dev/davinci0 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
-v /etc/ascend_install.info:/etc/ascend_install.info \
--privileged \
-p 8080:80 \
-e MODEL_DIR=/mnt/models \
-e MODEL_NAME=kokoro-v1_1-zh.pth \
tts:kokoro
```
参数说明:
- `MODEL_DIR`:模型所在目录(挂载到容器内 `/mnt/models`
- `MODEL_NAME`:加载的模型文件名(通常为 `.safetensors`
- `-p 8080:80`:将容器内服务端口映射到宿主机 `8080`
### 3. 测试服务
```bash
curl --request POST "http://localhost:8080/tts" \
--header 'Content-Type: application/ssml+xml' \
--header 'User-Agent: curl' \
--data-raw '<speak version="1.0" xml:lang="zh">
<voice xml:lang="zh" xml:gender="Female" name="zh">
今天天气很好,不知道明天天气怎么样。
</voice>
</speak>' \
--output sound.pcm
```
---

View File

@@ -0,0 +1,2 @@
torch==2.7.1
numpy==1.26.4

Binary file not shown.

View File

@@ -0,0 +1,318 @@
import os
import io
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
from kokoro import KPipeline, KModel
# import soundfile as sf
import wave
import numpy as np
from scipy.signal import resample
import torch
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List
import re
from dataclasses import dataclass
# --- /safe_abs_angle_patch ---
__ORIG_TORCH_ABS__ = torch.abs
__ORIG_TORCH_ANGLE__ = getattr(torch, "angle", None)
def _safe_mag_from_complex(x: torch.Tensor) -> torch.Tensor:
# 用 view_as_real 拆分为 [re, im],并在 float32 上计算,避免 NPU 复数内核
# 不做 .cpu(),保持在 NPU 上运行
if torch.is_complex(x):
ri = torch.view_as_real(x.to(torch.complex64)) # (..., 2)
re = ri[..., 0].float()
im = ri[..., 1].float()
else:
# 某些算子可能直接返回 (..., 2) 的实张量作为复数 [re, im]
# 兼容这种情况
if x.size(-1) == 2:
re = x[..., 0].float()
im = x[..., 1].float()
else:
# 非复数直接走原 abs
return __ORIG_TORCH_ABS__(x)
return torch.sqrt(torch.clamp(re * re + im * im, min=1e-12))
def _safe_angle_from_complex(x: torch.Tensor) -> torch.Tensor:
if torch.is_complex(x):
ri = torch.view_as_real(x.to(torch.complex64))
re = ri[..., 0].float()
im = ri[..., 1].float()
else:
if x.size(-1) == 2:
re = x[..., 0].float()
im = x[..., 1].float()
else:
# 非复数直接走原 angle若存在否则返回 0
return __ORIG_TORCH_ANGLE__(x) if __ORIG_TORCH_ANGLE__ else torch.zeros_like(x)
return torch.atan2(im, re)
def _patched_abs(x):
# 仅在复数/疑似复数时使用安全实现,否则用原生
try:
if torch.is_complex(x) or (x.dtype.is_floating and x.ndim >= 1 and x.size(-1) == 2):
return _safe_mag_from_complex(x)
return __ORIG_TORCH_ABS__(x)
except Exception:
# 兜底:任何异常回退到原生
return __ORIG_TORCH_ABS__(x)
def _patched_angle(x):
try:
if torch.is_complex(x) or (x.dtype.is_floating and x.ndim >= 1 and x.size(-1) == 2):
return _safe_angle_from_complex(x)
# 没有 torch.angle 的旧版本直接返回 0
return __ORIG_TORCH_ANGLE__(x) if __ORIG_TORCH_ANGLE__ else torch.zeros_like(x)
except Exception:
return __ORIG_TORCH_ANGLE__(x) if __ORIG_TORCH_ANGLE__ else torch.zeros_like(x)
# 安装补丁(全局生效)
torch.abs = _patched_abs
torch.angle = _patched_angle
# --- /safe_abs_angle_patch ---
repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'
MODEL_SR = 24000
TARGET_SR = 16000
# How much silence to insert between paragraphs: 5000 is about 0.2 seconds
N_ZEROS = 20
model = None
en_empty_pipeline = None
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
model_dir = os.getenv('MODEL_DIR', '/model/hexgrad')
model_name = os.getenv('MODEL_NAME','kokoro-v1_1-zh.pth')
# model_1_1_dir = os.path.join(model_dir, 'Kokoro-82M-v1.1-zh')
# model_1_0_dir = os.path.join(model_dir, 'Kokoro-82M')
# repo_id_1_0 = 'hexgrad/Kokoro-82M'
@dataclass
class LanguagePipeline:
pipeline: KPipeline
voice_pt: str
pipeline_dict: dict[str, LanguagePipeline] = {}
def en_callable(text):
if text == 'Kokoro':
return 'kˈOkəɹO'
elif text == 'Sol':
return 'sˈOl'
return next(en_empty_pipeline(text)).phonemes
# HACK: Mitigate rushing caused by lack of training data beyond ~100 tokens
# Simple piecewise linear fn that decreases speed as len_ps increases
def speed_callable(len_ps):
speed = 0.8
if len_ps <= 83:
speed = 1
elif len_ps < 183:
speed = 1 - (len_ps - 83) / 500
return speed
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
# This will create a wave header then append the frame input
# It should be first on a streaming wav file
# Other frames better should not have it (else you will hear some artifacts each chunk start)
wav_buf = io.BytesIO()
with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels)
vfout.setsampwidth(sample_width)
vfout.setframerate(sample_rate)
vfout.writeframes(frame_input)
wav_buf.seek(0)
return wav_buf.read()
def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
ori_dtype = data.dtype
# data = normalize_audio(data)
number_of_samples = int(len(data) * float(target_rate) / original_rate)
resampled_data = resample(data, number_of_samples)
# resampled_data = normalize_audio(resampled_data)
return resampled_data.astype(ori_dtype)
def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
audio = resample_audio(data, original_rate, target_rate)
if audio.dtype == np.float32:
audio = np.int16(audio * 32767)
audio = np.concatenate([audio, np.zeros(N_ZEROS, dtype=np.int16)])
return audio
def init():
global model, en_empty_pipeline
global model_1_0
global pipeline_dict
device = 'npu' #'cuda' if torch.cuda.is_available() else 'cpu'
model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval()
en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)
pipeline_dict['zh'] = LanguagePipeline(pipeline=zh_pipeline, voice_pt=zh_voice_pt)
pipeline_dict['en'] = LanguagePipeline(pipeline=en_pipeline, voice_pt=en_voice_pt)
# v1.0 model for other languages
# model_1_0 = KModel(repo_id=repo_id_1_0, model=os.path.join(model_1_0_dir, 'kokoro-v1_0.pth'), config=os.path.join(model_1_0_dir, 'config.json')).to(device).eval()
# # es
# es_pipeline = KPipeline(lang_code='e', repo_id=repo_id_1_0, model=model_1_0)
# es_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ef_dora.pt')
# pipeline_dict['es'] = LanguagePipeline(pipeline=es_pipeline, voice_pt=es_voice_pt)
# # fr
# fr_pipeline = KPipeline(lang_code='f', repo_id=repo_id_1_0, model=model_1_0)
# fr_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ff_siwis.pt')
# pipeline_dict['fr'] = LanguagePipeline(pipeline=fr_pipeline, voice_pt=fr_voice_pt)
# # hi
# hi_pipeline = KPipeline(lang_code='h', repo_id=repo_id_1_0, model=model_1_0)
# hi_voice_pt = os.path.join(model_1_0_dir, 'voices', 'hf_alpha.pt')
# pipeline_dict['hi'] = LanguagePipeline(pipeline=hi_pipeline, voice_pt=hi_voice_pt)
# # it
# it_pipeline = KPipeline(lang_code='i', repo_id=repo_id_1_0, model=model_1_0)
# it_voice_pt = os.path.join(model_1_0_dir, 'voices', 'if_sara.pt')
# pipeline_dict['it'] = LanguagePipeline(pipeline=it_pipeline, voice_pt=it_voice_pt)
# # ja
# ja_pipeline = KPipeline(lang_code='j', repo_id=repo_id_1_0, model=model_1_0)
# ja_voice_pt = os.path.join(model_1_0_dir, 'voices', 'jf_alpha.pt')
# pipeline_dict['ja'] = LanguagePipeline(pipeline=ja_pipeline, voice_pt=ja_voice_pt)
# # pt
# pt_pipeline = KPipeline(lang_code='p', repo_id=repo_id_1_0, model=model_1_0)
# pt_voice_pt = os.path.join(model_1_0_dir, 'voices', 'pf_dora.pt')
# pipeline_dict['pt'] = LanguagePipeline(pipeline=pt_pipeline, voice_pt=pt_voice_pt)
warmup()
@asynccontextmanager
async def lifespan(app: FastAPI):
init()
yield
pass
app = FastAPI(lifespan=lifespan)
def warmup():
zh_pipeline = pipeline_dict['zh'].pipeline
voice = pipeline_dict['zh'].voice_pt
generator = zh_pipeline(text="语音合成测试TTS。", voice=voice, speed=speed_callable)
for _ in generator:
pass
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'
def contains_words(text):
return any(char not in symbols for char in text)
def cut_sentences(text) -> list[str]:
text = text.strip()
splits = re.split(r"([.;?!、。?!;])", text)
sentences = []
for i in range(0, len(splits), 2):
if i + 1 < len(splits):
s = splits[i] + splits[i + 1]
else:
s = splits[i]
s = s.strip()
if s:
sentences.append(s)
return sentences
LANGUAGE_ALIASES = {
'z': 'zh',
'a': 'en',
'e': 'es',
'f': 'fr',
'h': 'hi',
'i': 'it',
'j': 'ja',
'p': 'pt',
}
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...), include_header: bool = False):
try:
root = ET.fromstring(ssml)
voice_element = root.find(".//voice")
if voice_element is not None:
transcription = voice_element.text.strip()
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
# voice_name = voice_element.get("name", "zh-f-soft-1").strip()
else:
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
except ET.ParseError as e:
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
if not contains_words(transcription):
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
if include_header:
audio_header = wave_header_chunk(sample_rate=TARGET_SR)
audio = audio_header + audio
return Response(audio, media_type='audio/wav')
if language not in pipeline_dict:
if language in LANGUAGE_ALIASES:
language = LANGUAGE_ALIASES[language]
else:
return JSONResponse(status_code=400, content={"message": f"Language '{language}' not supported."})
def streaming_generator():
texts = cut_sentences(transcription)
has_yield = False
for text in texts:
if text.strip() and contains_words(text):
pipeline = pipeline_dict[language].pipeline
voice = pipeline_dict[language].voice_pt
if language == 'zh':
generator = pipeline(text=text, voice=voice, speed=speed_callable)
else:
generator = pipeline(text=text, voice=voice)
for (_, _, audio) in generator:
if include_header and not has_yield:
has_yield = True
yield wave_header_chunk(sample_rate=TARGET_SR)
yield audio_postprocess(audio.numpy(), MODEL_SR, TARGET_SR).tobytes()
return StreamingResponse(streaming_generator(), media_type='audio/wav')
@app.get("/health")
@app.get("/ready")
async def ready():
return JSONResponse(status_code=200, content={"status": "ok"})
@app.get("/health_check")
async def health_check():
try:
a = torch.ones(10, 20, dtype=torch.float32, device='npu')
b = torch.ones(20, 10, dtype=torch.float32, device='npu')
c = torch.matmul(a, b)
if c.sum() == 10 * 20 * 10:
return {"status": "ok"}
else:
raise HTTPException(status_code=503)
except Exception as e:
print(f'health_check failed')
raise HTTPException(status_code=503)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=80)

View File

@@ -0,0 +1,4 @@
#!/bin/bash
python3 kokoro_server.py

View File

@@ -0,0 +1,5 @@
kokoro>=0.8.2
misaki[zh]>=0.8.2
soundfile
fastapi
uvicorn[standard]