init matcha
This commit is contained in:
13
bi_v100-matcha/Dockerfile_matcha
Normal file
13
bi_v100-matcha/Dockerfile_matcha
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Matcha-TTS inference server image on the Corex (Iluvatar) PyTorch base.
FROM corex:3.2.1

WORKDIR /workspace

# System dependency: espeak-ng (phonemization backend — presumably used by the
# `phonemizer` package in requirements_matcha.txt).  Use apt-get (the stable,
# script-safe CLI), skip recommended packages, and drop the package lists so
# the layer stays small.  Installed before COPY so code-only changes do not
# invalidate this cached layer.
RUN apt-get update \
 && apt-get install -y --no-install-recommends espeak-ng \
 && rm -rf /var/lib/apt/lists/*

COPY requirements_matcha.txt constraints_matcha.txt matcha_server.py launch_matcha.sh /workspace/

# The constraints file pins torch to the Corex build so pip does not replace it.
RUN pip install -r requirements_matcha.txt -c constraints_matcha.txt
RUN pip install matcha-tts -c constraints_matcha.txt

ENTRYPOINT ["/bin/bash", "launch_matcha.sh"]
|
||||||
3
bi_v100-matcha/README.md
Normal file
3
bi_v100-matcha/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# tiangai100-matcha-tts

【语音合成】(speech synthesis)
|
||||||
3
bi_v100-matcha/constraints_matcha.txt
Normal file
3
bi_v100-matcha/constraints_matcha.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
torch==2.1.0+corex.3.2.1
|
||||||
|
numpy==1.23.5
|
||||||
|
scipy==1.14.1
|
||||||
4
bi_v100-matcha/launch_matcha.sh
Executable file
4
bi_v100-matcha/launch_matcha.sh
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
# Container entrypoint: hand the shell process over to the server with `exec`
# so Python becomes PID 1 and receives SIGTERM directly on `docker stop`.
exec python3 matcha_server.py
|
||||||
|
|
||||||
197
bi_v100-matcha/matcha_server.py
Normal file
197
bi_v100-matcha/matcha_server.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
import os

# Model artifacts are mounted into the container; both values can be
# overridden from the environment at deploy time.
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.ckpt")

import logging

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)

# enable custom patcher if available: a deployment may ship an optional
# custom_patcher.py next to the model; it is copied into the working directory
# and imported purely for its side effects.
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import shutil
    shutil.copyfile(patcher_path, "custom_patcher.py")
    try:
        import custom_patcher
        logger.info("Custom patcher has been applied.")
    except ImportError:
        logger.info("Failed to import custom_patcher. Ensure it is a valid Python module.")
else:
    logger.info("No custom_patcher found.")

import wave
import numpy as np
from scipy.signal import resample
import re

from fastapi import FastAPI, Response, Body, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import xml.etree.ElementTree as ET

import torch
torch.set_num_threads(4)

# Force the math scaled-dot-product-attention backend; the fused flash /
# mem-efficient kernels are disabled — presumably unsupported or unreliable
# on this accelerator stack (TODO confirm against the Corex build).
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

from torch import Tensor
from torch.nn import functional as F
from typing import Optional, List
|
||||||
|
|
||||||
|
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement ``ConvTranspose1d.forward`` that runs the convolution under
    fp16 CUDA autocast and promotes the result back to float32.

    Mirrors the stock forward implementation (padding-mode guard plus
    ``_output_padding`` computation) with only the kernel call wrapped in an
    autocast region.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # `_output_padding` takes a List (not Tuple/Sequence) argument because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
    out_padding = self._output_padding(
        input,
        output_size,
        self.stride,
        self.padding,
        self.kernel_size,
        1,  # number of spatial dims for a 1-d transposed convolution
        self.dilation,
    )

    # Execute in half precision, then cast back so downstream code keeps
    # working with float32 tensors.
    with torch.amp.autocast('cuda', dtype=torch.float16):
        result = F.conv_transpose1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            out_padding,
            self.groups,
            self.dilation,
        )
    return result.float()


# Monkey-patch every ConvTranspose1d in the process (the vocoder uses them)
# so its forward pass runs through the fp16 path above.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
|
||||||
|
|
||||||
|
from matcha.cli import load_matcha, load_vocoder, to_waveform, process_text
|
||||||
|
|
||||||
|
# Lazily-initialised model handles; populated by init() at application startup.
model = None
vocoder = None
denoiser = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sample rate the acoustic model produces audio at (Hz), overridable via env.
MODEL_SR = int(os.getenv("MODEL_SR", 22050))
# Forwarded to model.synthesise as length_scale.
speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))
# Sample rate of the PCM returned to clients (Hz).
TARGET_SR = 16000
# Number of int16 zero samples returned for punctuation-only requests.
N_ZEROS = 100
|
||||||
|
|
||||||
|
|
||||||
|
def init():
    """Load the acoustic model and vocoder, then run one warm-up synthesis.

    Populates the module-level ``model``, ``vocoder`` and ``denoiser`` handles.
    """
    global model, vocoder, denoiser

    checkpoint = os.path.join(model_dir, model_name)
    generator = os.path.join(model_dir, "generator_v1")

    model = load_matcha("custom_model", checkpoint, device)
    vocoder, denoiser = load_vocoder("hifigan_T2_v1", generator, device)

    # Warm up: synthesise one short utterance so the first real request does
    # not pay lazy-initialisation costs.
    warmup_text = "你好,欢迎使用语音合成服务。"
    for _ in generate(warmup_text):
        pass
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the models before serving; no shutdown work."""
    init()
    yield
|
||||||
|
|
||||||
|
# Application object; lifespan() loads the models before traffic is accepted.
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
# Namespace prefix ElementTree uses for xml:lang attributes in SSML input.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
# Punctuation (ASCII + CJK) that carries no speakable content.
symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'


def contains_words(text):
    """Return True if *text* has at least one character that is not punctuation."""
    return not all(ch in symbols for ch in text)
|
||||||
|
|
||||||
|
def split_text(text, max_chars=135):
    """Split *text* into sentences on ASCII and CJK sentence-final punctuation.

    ``max_chars`` is currently unused (kept for interface compatibility);
    the result is one entry per sentence, whitespace-stripped, with empty
    pieces dropped.
    """
    pieces = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
    return [piece.strip() for piece in pieces if piece.strip()]
|
||||||
|
|
||||||
|
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``ori_sr`` to ``target_sr`` and convert float
    waveforms to int16 PCM.

    Args:
        audio: 1-D waveform; float input is assumed to lie in [-1.0, 1.0].
        ori_sr: sample rate of ``audio`` (Hz).
        target_sr: desired output sample rate (Hz).

    Returns:
        The (possibly resampled) waveform. Floating-point input is clipped to
        [-1, 1] and scaled to ``np.int16``; other dtypes pass through as-is.
    """
    if ori_sr != target_sr:
        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
        audio_resampled = resample(audio, number_of_samples)
    else:
        audio_resampled = audio
    # Test the *input* dtype for any floating kind: scipy's resample promotes
    # float32 to float64, and checking `== np.float32` alone would let float64
    # audio escape as raw float bytes into the int16 PCM stream.
    if np.issubdtype(audio.dtype, np.floating):
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    return audio_resampled
|
||||||
|
|
||||||
|
def generate(texts):
    """Yield ``TARGET_SR`` int16 PCM bytes for each sentence of ``texts``.

    The input is split into sentences; each one is phonemised, synthesised to
    a mel-spectrogram, vocoded, and yielded as raw PCM bytes.  Sentences whose
    text processing fails are logged and skipped.
    """
    chunks = split_text(texts)
    for i, chunk in enumerate(chunks):
        try:
            text_processed = process_text(0, chunk, device)
        except Exception as e:
            logger.error(f"Error processing text: {e}")
            # Skip this chunk: without `continue`, text_processed would be
            # undefined on the first chunk (NameError) or silently reuse the
            # previous chunk's phonemes afterwards.
            continue
        with torch.inference_mode():
            output = model.synthesise(
                text_processed["x"],
                text_processed["x_lengths"],
                n_timesteps=10,
                temperature=0.667,
                spks=None,
                length_scale=speaking_rate
            )
        # Vocode the mel-spectrogram and lightly denoise the waveform.
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, denoiser_strength=0.00025)
        audio = output["waveform"].detach().cpu().squeeze().numpy()
        yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/")
|
||||||
|
@app.post("/tts")
|
||||||
|
def predict(ssml: str = Body(...)):
|
||||||
|
try:
|
||||||
|
root = ET.fromstring(ssml)
|
||||||
|
voice_element = root.find(".//voice")
|
||||||
|
if voice_element is not None:
|
||||||
|
transcription = voice_element.text.strip()
|
||||||
|
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
|
||||||
|
# voice_name = voice_element.get("name", "zh-f-soft-1").strip()
|
||||||
|
else:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
|
||||||
|
except ET.ParseError as e:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
|
||||||
|
|
||||||
|
if not contains_words(transcription):
|
||||||
|
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
|
||||||
|
return Response(audio, media_type='audio/wav')
|
||||||
|
|
||||||
|
return StreamingResponse(generate(transcription), media_type='audio/wav')
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
@app.get("/ready")
|
||||||
|
async def ready():
|
||||||
|
return JSONResponse(status_code=200, content={"message": "success"})
|
||||||
|
|
||||||
|
@app.get("/health_check")
|
||||||
|
async def health_check():
|
||||||
|
try:
|
||||||
|
a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
|
||||||
|
b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
|
||||||
|
c = torch.matmul(a, b)
|
||||||
|
if c.sum() == 10 * 20 * 10:
|
||||||
|
return {"status": "ok"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=503)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'health_check failed')
|
||||||
|
raise HTTPException(status_code=503)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=80)
|
||||||
47
bi_v100-matcha/requirements_matcha.txt
Normal file
47
bi_v100-matcha/requirements_matcha.txt
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# --------- pytorch --------- #
|
||||||
|
torch>=2.0.0
|
||||||
|
torchvision>=0.15.0
|
||||||
|
lightning>=2.0.0
|
||||||
|
torchmetrics>=0.11.4
|
||||||
|
|
||||||
|
# --------- hydra --------- #
|
||||||
|
hydra-core==1.3.2
|
||||||
|
hydra-colorlog==1.2.0
|
||||||
|
hydra-optuna-sweeper==1.2.0
|
||||||
|
|
||||||
|
# --------- loggers --------- #
|
||||||
|
# wandb
|
||||||
|
# neptune-client
|
||||||
|
# mlflow
|
||||||
|
# comet-ml
|
||||||
|
# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
|
||||||
|
|
||||||
|
# --------- others --------- #
|
||||||
|
rootutils # standardizing the project root setup
|
||||||
|
pre-commit # hooks for applying linters on commit
|
||||||
|
rich # beautiful text formatting in terminal
|
||||||
|
pytest # tests
|
||||||
|
# sh # for running bash commands in some tests (linux/macos only)
|
||||||
|
phonemizer # phonemization of text
|
||||||
|
tensorboard
|
||||||
|
librosa
|
||||||
|
Cython
|
||||||
|
numpy
|
||||||
|
einops
|
||||||
|
inflect
|
||||||
|
Unidecode
|
||||||
|
scipy
|
||||||
|
torchaudio
|
||||||
|
matplotlib
|
||||||
|
pandas
|
||||||
|
conformer==0.3.2
|
||||||
|
diffusers # developed using version ==0.25.0
|
||||||
|
notebook
|
||||||
|
ipywidgets
|
||||||
|
gradio==3.43.2
|
||||||
|
gdown
|
||||||
|
wget
|
||||||
|
seaborn
|
||||||
|
|
||||||
|
fastapi
|
||||||
|
uvicorn[standard]
|
||||||
Reference in New Issue
Block a user