diff --git a/bi_v100-matcha/Dockerfile_matcha b/bi_v100-matcha/Dockerfile_matcha
new file mode 100644
index 0000000..6915c44
--- /dev/null
+++ b/bi_v100-matcha/Dockerfile_matcha
@@ -0,0 +1,15 @@
+FROM corex:3.2.1
+
+WORKDIR /workspace
+
+COPY requirements_matcha.txt constraints_matcha.txt matcha_server.py launch_matcha.sh /workspace/
+
+RUN pip install -r requirements_matcha.txt -c constraints_matcha.txt
+RUN pip install matcha-tts -c constraints_matcha.txt
+
+# espeak-ng is the phonemizer backend; drop apt lists to keep the image small.
+RUN apt update \
+    && apt install -y --no-install-recommends espeak-ng \
+    && rm -rf /var/lib/apt/lists/*
+
+ENTRYPOINT ["/bin/bash", "launch_matcha.sh"]
diff --git a/bi_v100-matcha/README.md b/bi_v100-matcha/README.md
new file mode 100644
index 0000000..d7ba481
--- /dev/null
+++ b/bi_v100-matcha/README.md
@@ -0,0 +1,3 @@
+# tiangai100-matcha-tts
+
+【语音合成】Matcha-TTS speech synthesis service
diff --git a/bi_v100-matcha/constraints_matcha.txt b/bi_v100-matcha/constraints_matcha.txt
new file mode 100644
index 0000000..03d1dda
--- /dev/null
+++ b/bi_v100-matcha/constraints_matcha.txt
@@ -0,0 +1,3 @@
+torch==2.1.0+corex.3.2.1
+numpy==1.23.5
+scipy==1.14.1
diff --git a/bi_v100-matcha/launch_matcha.sh b/bi_v100-matcha/launch_matcha.sh
new file mode 100755
index 0000000..c9ec741
--- /dev/null
+++ b/bi_v100-matcha/launch_matcha.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+# exec so the Python server receives container signals (SIGTERM) directly.
+exec python3 matcha_server.py
diff --git a/bi_v100-matcha/matcha_server.py b/bi_v100-matcha/matcha_server.py
new file mode 100644
index 0000000..bcc2318
--- /dev/null
+++ b/bi_v100-matcha/matcha_server.py
@@ -0,0 +1,198 @@
+import os
+
+# Model location is injected at deploy time via environment variables.
+model_dir = os.getenv("MODEL_DIR", "/mounted_model")
+model_name = os.getenv("MODEL_NAME", "model.ckpt")
+
+import logging
+logging.basicConfig(
+    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO"),
+)
+logger = logging.getLogger(__name__)
+
+# Optionally apply a model-specific monkey patch shipped alongside the checkpoint.
+patcher_path = os.path.join(model_dir, "custom_patcher.py")
+if os.path.exists(patcher_path):
+    import shutil
+    shutil.copyfile(patcher_path, "custom_patcher.py")
+    try:
+        import custom_patcher  # noqa: F401  # imported for its side effects
+        logger.info("Custom patcher has been applied.")
+    except ImportError:
+        logger.warning("Failed to import custom_patcher. Ensure it is a valid Python module.")
+else:
+    logger.info("No custom_patcher found.")
+
+import wave
+import numpy as np
+from scipy.signal import resample
+import re
+
+from fastapi import FastAPI, Response, Body, HTTPException
+from fastapi.responses import StreamingResponse, JSONResponse
+from contextlib import asynccontextmanager
+import uvicorn
+import xml.etree.ElementTree as ET
+
+import torch
+torch.set_num_threads(4)
+
+# Force the math SDP kernel; flash / mem-efficient attention are disabled for this device.
+torch.backends.cuda.enable_flash_sdp(False)
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_math_sdp(True)
+
+from torch import Tensor
+from torch.nn import functional as F
+from typing import Optional, List
+
+def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+    """ConvTranspose1d.forward patched to run the kernel under fp16 autocast, cast back to fp32."""
+    if self.padding_mode != 'zeros':
+        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
+
+    assert isinstance(self.padding, tuple)
+    # One cannot replace List by Tuple or Sequence in "_output_padding" because
+    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+    num_spatial_dims = 1
+    output_padding = self._output_padding(
+        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
+        num_spatial_dims, self.dilation)  # type: ignore[arg-type]
+    with torch.amp.autocast('cuda', dtype=torch.float16):
+        return F.conv_transpose1d(
+            input, self.weight, self.bias, self.stride, self.padding,
+            output_padding, self.groups, self.dilation).float()
+
+torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
+
+from matcha.cli import load_matcha, load_vocoder, to_waveform, process_text
+
+model = None
+vocoder = None
+denoiser = None
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+MODEL_SR = int(os.getenv("MODEL_SR", 22050))  # sample rate the model outputs
+speaking_rate = float(os.getenv("SPEAKING_RATE", 1.0))
+TARGET_SR = 16000  # sample rate delivered to clients
+N_ZEROS = 100  # length of the silence buffer returned for word-less input
+
+
+def init():
+    """Load the acoustic model and vocoder, then run one warmup synthesis."""
+    global model, vocoder, denoiser
+    ckpt_path = os.path.join(model_dir, model_name)
+    vocoder_path = os.path.join(model_dir, "generator_v1")
+    model = load_matcha("custom_model", ckpt_path, device)
+    vocoder, denoiser = load_vocoder("hifigan_T2_v1", vocoder_path, device)
+
+    # warmup:
+    for _ in generate("你好,欢迎使用语音合成服务。"):
+        pass
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Load models before the app starts serving requests.
+    init()
+    yield
+
+app = FastAPI(lifespan=lifespan)
+
+xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
+symbols = ',.!?;:()[]{}<>,。!?;:【】《》……"“”_—'
+
+def contains_words(text):
+    """Return True if *text* contains at least one non-punctuation character."""
+    return any(char not in symbols for char in text)
+
+def split_text(text, max_chars=135):
+    """Split *text* into sentences on Western/Chinese sentence-final punctuation."""
+    sentences = re.split(r"(?<=[;:.!?])\s+|(?<=[。!?])", text)
+    return [s.strip() for s in sentences if s.strip()]
+
+def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
+    """Resample from *ori_sr* to *target_sr* and convert float audio to int16 PCM."""
+    if ori_sr != target_sr:
+        number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
+        audio_resampled = resample(audio, number_of_samples)
+    else:
+        audio_resampled = audio
+    # Float audio of any width (not only float32) is assumed to lie in [-1, 1].
+    if np.issubdtype(audio.dtype, np.floating):
+        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
+        audio_resampled = (audio_resampled * 32767).astype(np.int16)
+    return audio_resampled
+
+def generate(texts):
+    """Yield int16 PCM bytes at TARGET_SR for each sentence of *texts*."""
+    chunks = split_text(texts)
+    for i, chunk in enumerate(chunks):
+        try:
+            text_processed = process_text(0, chunk, device)
+        except Exception as e:
+            # Skip unprocessable chunks; without this the loop fell through
+            # and crashed on an unbound `text_processed`.
+            logger.error(f"Error processing text: {e}")
+            continue
+        with torch.inference_mode():
+            output = model.synthesise(
+                text_processed["x"],
+                text_processed["x_lengths"],
+                n_timesteps=10,
+                temperature=0.667,
+                spks=None,
+                length_scale=speaking_rate
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, denoiser_strength=0.00025)
+            audio = output["waveform"].detach().cpu().squeeze().numpy()
+            yield audio_postprocess(audio, MODEL_SR, TARGET_SR).tobytes()
+
+
+@app.post("/")
+@app.post("/tts")
+def predict(ssml: str = Body(...)):
+    """Parse an SSML document, extract the <voice> text, and stream synthesized audio."""
+    try:
+        root = ET.fromstring(ssml)
+        voice_element = root.find(".//voice")
+        if voice_element is not None:
+            # .text is None for an empty <voice/> element; treat it as "".
+            transcription = (voice_element.text or "").strip()
+            language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
+            # voice_name = voice_element.get("name", "zh-f-soft-1").strip()
+        else:
+            return JSONResponse(status_code=400, content={"message": "Invalid SSML format: element not found."})
+    except ET.ParseError as e:
+        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
+
+    if not contains_words(transcription):
+        # No speakable content: respond with a short buffer of silence.
+        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
+        return Response(audio, media_type='audio/wav')
+
+    return StreamingResponse(generate(transcription), media_type='audio/wav')
+
+@app.get("/health")
+@app.get("/ready")
+async def ready():
+    return JSONResponse(status_code=200, content={"message": "success"})
+
+@app.get("/health_check")
+async def health_check():
+    """Deep health check: verify a small CUDA matmul yields the expected sum."""
+    try:
+        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
+        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
+        c = torch.matmul(a, b)
+        if c.sum() == 10 * 20 * 10:
+            return {"status": "ok"}
+        else:
+            raise HTTPException(status_code=503)
+    except Exception:
+        logger.exception("health_check failed")
+        raise HTTPException(status_code=503)
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=80)
diff --git a/bi_v100-matcha/requirements_matcha.txt b/bi_v100-matcha/requirements_matcha.txt
new file mode 100644
index 0000000..e569a95
--- /dev/null
+++ b/bi_v100-matcha/requirements_matcha.txt
@@ -0,0 +1,47 @@
+# --------- pytorch --------- #
+torch>=2.0.0
+torchvision>=0.15.0
+lightning>=2.0.0
+torchmetrics>=0.11.4
+
+# --------- hydra --------- #
+hydra-core==1.3.2
+hydra-colorlog==1.2.0
+hydra-optuna-sweeper==1.2.0
+
+# --------- loggers --------- #
+# wandb
+# neptune-client
+# mlflow
+# comet-ml
+# aim>=3.16.2  # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
+
+# --------- others --------- #
+rootutils       # standardizing the project root setup
+pre-commit      # hooks for applying linters on commit
+rich            # beautiful text formatting in terminal
+pytest          # tests
+# sh            # for running bash commands in some tests (linux/macos only)
+phonemizer      # phonemization of text
+tensorboard
+librosa
+Cython
+numpy
+einops
+inflect
+Unidecode
+scipy
+torchaudio
+matplotlib
+pandas
+conformer==0.3.2
+diffusers       # developed using version ==0.25.0
+notebook
+ipywidgets
+gradio==3.43.2
+gdown
+wget
+seaborn
+
+fastapi
+uvicorn[standard]