init muxi

This commit is contained in:
2025-09-12 11:39:55 +08:00
commit 96ef2da601
602 changed files with 591073 additions and 0 deletions

View File

@@ -0,0 +1,311 @@
import os
import io
from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
from kokoro import KPipeline, KModel
# import soundfile as sf
import wave
import numpy as np
from scipy.signal import resample
import torch
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List
import re
from dataclasses import dataclass
repo_id = 'hexgrad/Kokoro-82M-v1.1-zh'
MODEL_SR = 24000
TARGET_SR = 16000
# How much silence to insert between paragraphs: 5000 is about 0.2 seconds
N_ZEROS = 20
model = None
en_empty_pipeline = None
en_voice = os.getenv('EN_VOICE', 'af_maple.pt')
zh_voice = os.getenv('ZH_VOICE', 'zf_046.pt')
model_dir = os.getenv('MODEL_DIR', '/model/hexgrad')
model_name = os.getenv('MODEL_NAME','kokoro-v1_1-zh.pth')
# model_1_1_dir = os.path.join(model_dir, 'Kokoro-82M-v1.1-zh')
# model_1_0_dir = os.path.join(model_dir, 'Kokoro-82M')
# repo_id_1_0 = 'hexgrad/Kokoro-82M'
@dataclass
class LanguagePipeline:
pipeline: KPipeline
voice_pt: str
pipeline_dict: dict[str, LanguagePipeline] = {}
def en_callable(text):
if text == 'Kokoro':
return 'kˈOkəɹO'
elif text == 'Sol':
return 'sˈOl'
return next(en_empty_pipeline(text)).phonemes
# HACK: Mitigate rushing caused by lack of training data beyond ~100 tokens
# Simple piecewise linear fn that decreases speed as len_ps increases
def speed_callable(len_ps):
speed = 0.8
if len_ps <= 83:
speed = 1
elif len_ps < 183:
speed = 1 - (len_ps - 83) / 500
return speed
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
# This will create a wave header then append the frame input
# It should be first on a streaming wav file
# Other frames better should not have it (else you will hear some artifacts each chunk start)
wav_buf = io.BytesIO()
with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels)
vfout.setsampwidth(sample_width)
vfout.setframerate(sample_rate)
vfout.writeframes(frame_input)
wav_buf.seek(0)
return wav_buf.read()
def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
ori_dtype = data.dtype
# data = normalize_audio(data)
number_of_samples = int(len(data) * float(target_rate) / original_rate)
resampled_data = resample(data, number_of_samples)
# resampled_data = normalize_audio(resampled_data)
return resampled_data.astype(ori_dtype)
def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
audio = resample_audio(data, original_rate, target_rate)
if audio.dtype == np.float32:
audio = np.int16(audio * 32767)
audio = np.concatenate([audio, np.zeros(N_ZEROS, dtype=np.int16)])
return audio
# ================== decoder/istftnet 补丁 ==================
def _to_cpu_fp32(obj):
if torch.is_tensor(obj):
return obj.detach().to("cpu", dtype=torch.float32)
if isinstance(obj, (list, tuple)):
return type(obj)(_to_cpu_fp32(x) for x in obj)
if isinstance(obj, dict):
return {k: _to_cpu_fp32(v) for k, v in obj.items()}
return obj
def patch_decoder(model, device: str):
decoder = getattr(model, "decoder", None)
if decoder is None:
raise RuntimeError("未找到 model.decoder请 print(model) 确认实际模块名。")
decoder.eval().to(device).float()
try: torch.nn.utils.remove_weight_norm(decoder)
except Exception: pass
for p in decoder.parameters():
p.requires_grad = False
if p.dtype != torch.float32: p.data = p.data.float()
for n, b in decoder.named_buffers():
if b.dtype != torch.float32: setattr(decoder, n, b.float())
orig_forward = decoder.forward
def forward_patched(*args, **kwargs):
# 关键decoder 运行时,强制 FP32、并保证输入与 decoder 在同一设备
with torch.amp.autocast('cuda', enabled=False):
if device == "cpu":
args_ = _to_cpu_fp32(args); kwargs_ = _to_cpu_fp32(kwargs)
out = orig_forward(*args_, **kwargs_)
return _to_cpu_fp32(out)
else: # device == "cuda"
def to_gpu_fp32(x):
if torch.is_tensor(x):
return x.detach().to("cuda", dtype=torch.float32)
if isinstance(x, (list, tuple)):
return type(x)(to_gpu_fp32(t) for t in x)
if isinstance(x, dict):
return {k: to_gpu_fp32(v) for k, v in x.items()}
return x
args_ = to_gpu_fp32(args); kwargs_ = to_gpu_fp32(kwargs)
out = orig_forward(*args_, **kwargs_)
if torch.is_tensor(out):
return out.detach().to(device, dtype=torch.float32)
return out
decoder.forward = forward_patched
def init():
global model, en_empty_pipeline
global model_1_0
global pipeline_dict
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval()
patch_decoder(model, "cpu")
en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
zh_pipeline = KPipeline(lang_code='z', repo_id=repo_id, model=model, en_callable=en_callable)
en_voice_pt = os.path.join(model_dir, 'voices', en_voice)
zh_voice_pt = os.path.join(model_dir, 'voices', zh_voice)
pipeline_dict['zh'] = LanguagePipeline(pipeline=zh_pipeline, voice_pt=zh_voice_pt)
pipeline_dict['en'] = LanguagePipeline(pipeline=en_pipeline, voice_pt=en_voice_pt)
# v1.0 model for other languages
# model_1_0 = KModel(repo_id=repo_id_1_0, model=os.path.join(model_1_0_dir, 'kokoro-v1_0.pth'), config=os.path.join(model_1_0_dir, 'config.json')).to(device).eval()
# # es
# es_pipeline = KPipeline(lang_code='e', repo_id=repo_id_1_0, model=model_1_0)
# es_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ef_dora.pt')
# pipeline_dict['es'] = LanguagePipeline(pipeline=es_pipeline, voice_pt=es_voice_pt)
# # fr
# fr_pipeline = KPipeline(lang_code='f', repo_id=repo_id_1_0, model=model_1_0)
# fr_voice_pt = os.path.join(model_1_0_dir, 'voices', 'ff_siwis.pt')
# pipeline_dict['fr'] = LanguagePipeline(pipeline=fr_pipeline, voice_pt=fr_voice_pt)
# # hi
# hi_pipeline = KPipeline(lang_code='h', repo_id=repo_id_1_0, model=model_1_0)
# hi_voice_pt = os.path.join(model_1_0_dir, 'voices', 'hf_alpha.pt')
# pipeline_dict['hi'] = LanguagePipeline(pipeline=hi_pipeline, voice_pt=hi_voice_pt)
# # it
# it_pipeline = KPipeline(lang_code='i', repo_id=repo_id_1_0, model=model_1_0)
# it_voice_pt = os.path.join(model_1_0_dir, 'voices', 'if_sara.pt')
# pipeline_dict['it'] = LanguagePipeline(pipeline=it_pipeline, voice_pt=it_voice_pt)
# # ja
# ja_pipeline = KPipeline(lang_code='j', repo_id=repo_id_1_0, model=model_1_0)
# ja_voice_pt = os.path.join(model_1_0_dir, 'voices', 'jf_alpha.pt')
# pipeline_dict['ja'] = LanguagePipeline(pipeline=ja_pipeline, voice_pt=ja_voice_pt)
# # pt
# pt_pipeline = KPipeline(lang_code='p', repo_id=repo_id_1_0, model=model_1_0)
# pt_voice_pt = os.path.join(model_1_0_dir, 'voices', 'pf_dora.pt')
# pipeline_dict['pt'] = LanguagePipeline(pipeline=pt_pipeline, voice_pt=pt_voice_pt)
warmup()
@asynccontextmanager
async def lifespan(app: FastAPI):
init()
yield
pass
app = FastAPI(lifespan=lifespan)
def warmup():
zh_pipeline = pipeline_dict['zh'].pipeline
voice = pipeline_dict['zh'].voice_pt
generator = zh_pipeline(text="语音合成测试TTS。", voice=voice, speed=speed_callable)
for _ in generator:
pass
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'
def contains_words(text):
return any(char not in symbols for char in text)
def cut_sentences(text) -> list[str]:
text = text.strip()
splits = re.split(r"([.;?!、。?!;])", text)
sentences = []
for i in range(0, len(splits), 2):
if i + 1 < len(splits):
s = splits[i] + splits[i + 1]
else:
s = splits[i]
s = s.strip()
if s:
sentences.append(s)
return sentences
LANGUAGE_ALIASES = {
'z': 'zh',
'a': 'en',
'e': 'es',
'f': 'fr',
'h': 'hi',
'i': 'it',
'j': 'ja',
'p': 'pt',
}
@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...), include_header: bool = False):
try:
root = ET.fromstring(ssml)
voice_element = root.find(".//voice")
if voice_element is not None:
transcription = voice_element.text.strip()
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
# voice_name = voice_element.get("name", "zh-f-soft-1").strip()
else:
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
except ET.ParseError as e:
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
if not contains_words(transcription):
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
if include_header:
audio_header = wave_header_chunk(sample_rate=TARGET_SR)
audio = audio_header + audio
return Response(audio, media_type='audio/wav')
if language not in pipeline_dict:
if language in LANGUAGE_ALIASES:
language = LANGUAGE_ALIASES[language]
else:
return JSONResponse(status_code=400, content={"message": f"Language '{language}' not supported."})
def streaming_generator():
texts = cut_sentences(transcription)
has_yield = False
for text in texts:
if text.strip() and contains_words(text):
pipeline = pipeline_dict[language].pipeline
voice = pipeline_dict[language].voice_pt
if language == 'zh':
generator = pipeline(text=text, voice=voice, speed=speed_callable)
else:
generator = pipeline(text=text, voice=voice)
for (_, _, audio) in generator:
if include_header and not has_yield:
has_yield = True
yield wave_header_chunk(sample_rate=TARGET_SR)
yield audio_postprocess(audio.numpy(), MODEL_SR, TARGET_SR).tobytes()
return StreamingResponse(streaming_generator(), media_type='audio/wav')
@app.get("/health")
@app.get("/ready")
async def ready():
return JSONResponse(status_code=200, content={"status": "ok"})
@app.get("/health_check")
async def health_check():
try:
a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
c = torch.matmul(a, b)
if c.sum() == 10 * 20 * 10:
return {"status": "ok"}
else:
raise HTTPException(status_code=503)
except Exception as e:
print(f'health_check failed')
raise HTTPException(status_code=503)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=80)