init ascend tts
This commit is contained in:
182
ascend_910-matcha/matcha_server.py
Normal file
182
ascend_910-matcha/matcha_server.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import os

# Model artifacts are mounted into the container; both overridable via env.
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

import logging

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)

import wave  # NOTE(review): appears unused in this file — confirm before removing
import numpy as np
from scipy.signal import resample
import re

from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET
import torch

# Pin the default dtype to float32; the NPU hann_window workaround below
# also assumes float32 windows.
torch.set_default_dtype(torch.float32)
|
||||
|
||||
|
||||
_original_hann_window = torch.hann_window


def _safe_hann_window(window_length,
                      periodic=True,
                      *,
                      dtype=None,
                      layout=torch.strided,
                      device=None,
                      requires_grad=False,
                      **kwargs):
    """Drop-in replacement for ``torch.hann_window`` that is NPU-safe.

    The Ascend NPU backend does not support the int64 / in-place cos
    path used by the native implementation, so the window is always
    built on the CPU (defaulting to float32) and moved to the requested
    device afterwards.
    """
    target_dtype = torch.float32 if dtype is None else dtype

    # Always materialize on the CPU first to bypass the unsupported
    # in-place cos kernel on the NPU.
    window = _original_hann_window(
        window_length,
        periodic=periodic,
        dtype=target_dtype,
        layout=layout,
        device="cpu",
        requires_grad=requires_grad,
        **kwargs,
    )

    return window if device is None else window.to(device)


# Monkey-patch so every downstream torch.hann_window call is NPU-safe.
torch.hann_window = _safe_hann_window
|
||||
|
||||
from torch import Tensor  # NOTE(review): Tensor, F, Optional, List look unused here — confirm
from torch.nn import functional as F
from typing import Optional, List

from matcha.cli import load_matcha, load_vocoder, to_waveform, process_text

# Populated by init() at application startup.
model = None
vocoder = None
denoiser = None
# All tensors and models are placed on the Ascend NPU.
device = 'npu'

# Sample rate (Hz) the acoustic model / vocoder produce at.
MODEL_SR = int(os.getenv("MODEL_SR", 22050))
# Passed to model.synthesise as length_scale (speech-speed control).
speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))
# Output PCM sample rate (Hz) after resampling.
TARGET_SR = 16000
# Number of zero (silence) samples returned for punctuation-only requests.
N_ZEROS = 100
|
||||
def init():
    """Load the acoustic model and vocoder, then run a warm-up synthesis.

    Assigns the module-level ``model``, ``vocoder`` and ``denoiser``
    globals used by generate().
    """
    global model, vocoder, denoiser

    checkpoint = os.path.join(model_dir, model_name)
    hifigan_dir = os.path.join(model_dir, "generator_v1")

    model = load_matcha("custom_model", checkpoint, device)
    vocoder, denoiser = load_vocoder("hifigan_T2_v1", hifigan_dir, device)

    # Warm-up: drive the full pipeline once so the first real request does
    # not pay first-call initialization cost.
    warmup_text = "你好,欢迎使用语音合成服务。"
    for _ in generate(warmup_text):
        pass
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models before serving; no teardown needed."""
    init()
    yield
|
||||
|
||||
# Wire model loading into the app lifecycle (init() runs before serving).
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Namespace prefix for the xml:lang attribute on SSML elements.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"

# Punctuation (ASCII + fullwidth CJK) that carries no speakable content.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if *text* has at least one non-punctuation character."""
    for char in text:
        if char not in symbols:
            return True
    return False
|
||||
|
||||
def split_text(text, max_chars=135):
    """Split *text* into sentence chunks no longer than *max_chars*.

    Sentences are delimited by ASCII punctuation followed by whitespace,
    or by fullwidth CJK sentence enders (。!?). Any single sentence
    longer than *max_chars* is hard-sliced into fixed-size pieces.

    Fix: *max_chars* was previously accepted but ignored, so arbitrarily
    long sentences were passed to the model whole.
    """
    sentences = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    chunks = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(sentence) <= max_chars:
            chunks.append(sentence)
        else:
            # Fallback hard split for run-on sentences with no punctuation.
            chunks.extend(
                sentence[i:i + max_chars]
                for i in range(0, len(sentence), max_chars)
            )
    return chunks
|
||||
|
||||
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from *ori_sr* to *target_sr* and convert float
    waveforms (assumed in [-1, 1]) to int16 PCM.

    Fix: the int16 conversion was keyed on ``audio.dtype == np.float32``
    exactly, so a float64 waveform silently escaped conversion and leaked
    floats into the PCM byte stream. Any floating input is now converted.
    """
    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        # NOTE(review): scipy's resample always returns floats; an integer
        # input that goes through this path still comes out as float, as in
        # the original code — confirm whether integer inputs ever occur.
        audio_resampled = resample(audio, number_of_samples)
    else:
        audio_resampled = audio
    if np.issubdtype(audio.dtype, np.floating):
        # Clip before scaling to avoid int16 wrap-around on overshoot.
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    return audio_resampled
|
||||
|
||||
def generate(texts):
    """Yield int16 PCM byte chunks (TARGET_SR Hz) synthesized from *texts*.

    The text is split into sentence chunks; each chunk is phonemized,
    synthesized with the Matcha model, vocoded, resampled and yielded as
    raw bytes for streaming.
    """
    chunks = split_text(texts)
    for i, chunk in enumerate(chunks):
        try:
            text_processed = process_text(0, chunk, device)
        except Exception as e:
            logger.error(f"Error processing text: {e}")
            # Fix: skip this chunk on failure. Previously execution fell
            # through, raising NameError on the first chunk or silently
            # re-synthesising the previous chunk's phonemes.
            continue
        with torch.inference_mode():
            output = model.synthesise(
                text_processed["x"],
                text_processed["x_lengths"],
                n_timesteps=10,
                temperature=0.667,
                spks=None,
                length_scale=speaking_rate
            )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, denoiser_strength=0.00025)
        audio = output["waveform"].detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
|
||||
|
||||
|
||||
@app.post("/")
|
||||
@app.post("/tts")
|
||||
def predict(ssml: str = Body(...)):
|
||||
try:
|
||||
root = ET.fromstring(ssml)
|
||||
voice_element = root.find(".//voice")
|
||||
if voice_element is not None:
|
||||
transcription = voice_element.text.strip()
|
||||
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
|
||||
# voice_name = voice_element.get("name", "zh-f-soft-1").strip()
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
|
||||
except ET.ParseError as e:
|
||||
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
|
||||
|
||||
if not contains_words(transcription):
|
||||
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
|
||||
return Response(audio, media_type='audio/wav')
|
||||
|
||||
return StreamingResponse(generate(transcription), media_type='audio/wav')
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/ready")
|
||||
async def ready():
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
|
||||
@app.get("/health_check")
|
||||
async def health_check():
|
||||
try:
|
||||
a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
|
||||
b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
|
||||
c = torch.matmul(a, b)
|
||||
if c.sum() == 10 * 20 * 10:
|
||||
return {"status": "ok"}
|
||||
else:
|
||||
raise HTTPException(status_code=503)
|
||||
except Exception as e:
|
||||
print(f'health_check failed')
|
||||
raise HTTPException(status_code=503)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=80)
|
||||
Reference in New Issue
Block a user