2025-08-20 16:10:23 +08:00
|
|
|
|
import logging
import os

# Model location; both values can be overridden from the environment.
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.safetensors")

# logging.basicConfig only accepts upper-case level names ("DEBUG", not
# "debug"); an un-normalised LOGLEVEL would raise ValueError at import time.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").strip().upper(),
)
logger = logging.getLogger(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
# Enable an optional, deployment-specific patch module shipped next to the
# model weights. The file is copied into the current working directory so a
# plain `import custom_patcher` can find it (this assumes the CWD is on
# sys.path for the launch command in use — TODO confirm).
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import shutil

    shutil.copyfile(patcher_path, "custom_patcher.py")
    try:
        import custom_patcher  # noqa: F401  (imported only for its side effects)
    except ImportError:
        # Keep serving without the patch, but record *why* it failed to load.
        logger.warning(
            "Failed to import custom_patcher. Ensure it is a valid Python module.",
            exc_info=True,
        )
    else:
        logger.info("Custom patcher has been applied.")
else:
    logger.info("No custom_patcher found.")
|
|
|
|
|
|
|
2025-08-12 14:15:41 +08:00
|
|
|
|
import torch

# Intra-op CPU thread count; defaults to 4, override with TORCH_NUM_THREADS.
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))

# Restrict scaled-dot-product attention to the plain "math" backend by
# disabling the flash and memory-efficient CUDA kernels (presumably for
# numerical or compatibility reasons — TODO confirm the original motivation).
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
|
|
|
|
|
|
|
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
|
from typing import Optional, List
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
# def custom_conv1d_forward(self, input: Tensor, debug=False) -> Tensor:
|
|
|
|
|
|
# with torch.amp.autocast(input.device.type, dtype=torch.float):
|
|
|
|
|
|
# return self._conv_forward(input, self.weight, self.bias)
|
|
|
|
|
|
|
|
|
|
|
|
# torch.nn.Conv1d.forward = custom_conv1d_forward
|
|
|
|
|
|
|
|
|
|
|
|
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Drop-in ``torch.nn.ConvTranspose1d.forward`` that runs under fp16 CUDA autocast.

    Follows the stock PyTorch implementation: only zero padding is supported,
    and the output padding is resolved from ``output_size`` via the module's
    ``_output_padding`` helper. The functional call is wrapped in
    ``torch.amp.autocast('cuda', dtype=torch.float16)`` and the result is cast
    back with ``.float()``.

    Args:
        self: the ``ConvTranspose1d`` module being called.
        input: tensor fed to the module.
        output_size: optional requested output size, forwarded to
            ``_output_padding``.

    Returns:
        The transposed-convolution output as a float32 tensor.

    Raises:
        ValueError: if ``self.padding_mode`` is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # `_output_padding` takes List rather than Tuple/Sequence because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
    spatial_dims = 1
    extra_padding = self._output_padding(
        input,
        output_size,
        self.stride,  # type: ignore[arg-type]
        self.padding,  # type: ignore[arg-type]
        self.kernel_size,  # type: ignore[arg-type]
        spatial_dims,
        self.dilation,  # type: ignore[arg-type]
    )
    with torch.amp.autocast('cuda', dtype=torch.float16):
        result = F.conv_transpose1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            extra_padding,
            self.groups,
            self.dilation,
        )
        return result.float()


# Install the override process-wide: every ConvTranspose1d instance, whichever
# library created it, now uses the autocast forward defined above.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
|
|
|
|
|
|
|
|
|
|
|
|
from f5_tts.infer.utils_infer import (
|
|
|
|
|
|
load_vocoder,
|
|
|
|
|
|
load_model,
|
2025-08-20 16:10:23 +08:00
|
|
|
|
preprocess_ref_audio_text,
|
|
|
|
|
|
infer_process,
|
2025-08-12 14:15:41 +08:00
|
|
|
|
infer_batch_process,
|
|
|
|
|
|
)
|
|
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
from hydra.utils import get_class
|
2025-08-20 16:10:23 +08:00
|
|
|
|
import torch
|
|
|
|
|
|
import re
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import soundfile as sf
|
2025-08-12 14:15:41 +08:00
|
|
|
|
import torchaudio
|
2025-08-20 16:10:23 +08:00
|
|
|
|
from scipy import signal
|
2025-08-12 14:15:41 +08:00
|
|
|
|
import io
|
2025-08-20 16:10:23 +08:00
|
|
|
|
import time
|
2025-08-12 14:15:41 +08:00
|
|
|
|
|
2025-08-20 16:10:23 +08:00
|
|
|
|
from fastapi import FastAPI, Request, Response, Body, HTTPException
|
2025-08-12 14:15:41 +08:00
|
|
|
|
from fastapi import UploadFile, File, Form
|
|
|
|
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
import os
|
2025-08-20 16:10:23 +08:00
|
|
|
|
import hashlib
|
|
|
|
|
|
import xml.etree.ElementTree as ET
|
|
|
|
|
|
from typing import Union
|
2025-08-12 14:15:41 +08:00
|
|
|
|
|
2025-08-20 16:10:23 +08:00
|
|
|
|
vocoder_dir = os.getenv('VOCODER_DIR', '/workspace/charactr/vocos-mel-24khz')
speed = float(os.getenv('SPEED', 1.0))

# Model objects, assigned by init() at startup.
ema_model = None
vocoder = None
# audio_key -> {'ref_audio': <uploaded bytes>, 'ref_text': <transcript>},
# filled by the /register_voice endpoint.
voice_dict = {}
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sample rate of the PCM streamed back to clients.
TARGET_SR = 16000
# Number of int16 zero samples returned when the text has nothing to speak.
N_ZEROS = 20

# Default reference voice used by /tts. Like the other settings above, the
# paths can be overridden from the environment (e.g. to point inside MODEL_DIR).
std_ref_audio_file = os.getenv('REF_AUDIO_FILE', '/workspace/ref_audio.wav')
std_ref_text_file = os.getenv('REF_TEXT_FILE', '/workspace/ref_text.txt')
# Contents of the two files above, loaded by init().
std_ref_audio = None
std_ref_text = None
|
2025-08-12 14:15:41 +08:00
|
|
|
|
|
|
|
|
|
|
def init():
    """Load the vocoder, the F5-TTS model and the default reference voice.

    Assigns the module globals ``vocoder`` (local Vocos checkpoint from
    ``vocoder_dir``), ``ema_model`` (checkpoint ``model_dir/model_name`` with
    ``model_dir/vocab.txt``), ``std_ref_audio`` (raw bytes of
    ``std_ref_audio_file``) and ``std_ref_text`` (stripped UTF-8 contents of
    ``std_ref_text_file``). Any missing file propagates as an exception.
    """
    global ema_model, vocoder, std_ref_audio, std_ref_text
    logger.info(f'{device=}')

    mel_type = 'vocos'
    vocoder = load_vocoder(vocoder_name=mel_type, is_local=True, local_path=vocoder_dir, device=device)

    cfg = OmegaConf.load('/workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')
    backbone_cls = get_class(f'f5_tts.model.{cfg.model.backbone}')
    ema_model = load_model(
        backbone_cls,
        cfg.model.arch,
        os.path.join(model_dir, model_name),
        mel_spec_type=mel_type,
        vocab_file=os.path.join(model_dir, 'vocab.txt'),
        device=device,
    )

    with open(std_ref_audio_file, 'rb') as audio_fh:
        std_ref_audio = audio_fh.read()
    with open(std_ref_text_file, 'r', encoding='utf-8') as text_fh:
        std_ref_text = text_fh.read().strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-08-12 14:15:41 +08:00
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: run init() before serving; no shutdown work."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
2025-08-20 16:10:23 +08:00
|
|
|
|
@app.get("/health")
@app.get("/ready")
async def ready():
    """Lightweight probe served on both /health and /ready.

    Always answers 200 with ``{"message": "success"}``; it does not exercise
    the model or the GPU.
    """
    payload = {"message": "success"}
    return JSONResponse(content=payload, status_code=200)
|
|
|
|
|
|
|
|
|
|
|
|
def encode_audio_key(audio_bytes: bytes) -> str:
    """Derive a short voice key from uploaded reference audio.

    Only the first 16000 bytes are hashed (MD5), and the hex digest is cut to
    16 characters. Two uploads that share their first 16000 bytes therefore
    map to the same key.
    """
    head = audio_bytes[:16000]
    full_digest = hashlib.md5(head).hexdigest()
    return full_digest[:16]
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/register_voice")
async def register_voice(
    audio: UploadFile = File(...),
    text: str = Form(...)
):
    """Register a reference voice and return the key used to synthesize with it.

    The uploaded audio bytes and the transcript are stored in ``voice_dict``
    under ``encode_audio_key(audio_bytes)``. The transcript is stripped and
    made to end with "." unless it already ends with "." or "。". A short
    warm-up synthesis then runs with the new voice before responding.

    Returns:
        JSONResponse 200 with ``{"status": "success", "audio_key": <key>}``.

    Raises:
        Whatever the warm-up raises (e.g. undecodable audio). In that case the
        registration is rolled back, so no unusable key is left behind.
    """
    from fastapi.concurrency import run_in_threadpool

    audio_bytes = await audio.read()
    audio_key = encode_audio_key(audio_bytes)

    # Strip first so trailing whitespace can't yield "text ." below.
    text = text.strip()
    # Ensure ref_text ends with a proper sentence-ending punctuation.
    if not text.endswith(("。", ".")):
        text += "."

    previous = voice_dict.get(audio_key)
    voice_dict[audio_key] = {
        'ref_audio': audio_bytes,
        'ref_text': text,
    }

    def _warmup():
        """Drain one short synthesis with the newly registered voice."""
        for _ in generate("流式语音合成,合成测试", audio_key, fast_infer=2):
            logger.info("Warming up")

    try:
        # Inference is blocking work; running it in the threadpool keeps the
        # event loop (and every other endpoint, incl. probes) responsive.
        await run_in_threadpool(_warmup)
    except Exception:
        logger.exception("Warm-up failed for voice %s; registration rolled back.", audio_key)
        if previous is None:
            voice_dict.pop(audio_key, None)
        else:
            voice_dict[audio_key] = previous
        raise

    response = {
        "status": "success",
        "audio_key": audio_key
    }
    return JSONResponse(status_code=200, content=response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Characters that do not count as speakable content.
symbols = """,.!?;:()[]{}<>,。!?;:【】《》……'"’“”_—"""


def contains_words(text):
    """Return True if *text* has at least one speakable character.

    A character is speakable when it is neither whitespace nor listed in
    ``symbols``. Whitespace is ignored so that strings such as ``"  "`` or
    ``", "`` are not mistaken for content.
    """
    return any(not ch.isspace() and ch not in symbols for ch in text)


# Sentence boundaries: whitespace following punctuation from the first class,
# or the position right after punctuation from the second class.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])")


def _with_separator(sentence):
    """Append a space to sentences whose last character is a single UTF-8 byte."""
    if sentence and len(sentence[-1].encode("utf-8")) == 1:
        return sentence + " "
    return sentence


def split_text(text, max_chars=135, cut_short_first=False):
    """Split *text* into synthesis batches of roughly *max_chars* UTF-8 bytes.

    The text is cut into sentences at ``_SENTENCE_BOUNDARY``. Sentences are then
    packed greedily into chunks whose UTF-8 byte length stays within
    *max_chars*; a single sentence longer than that becomes its own chunk.
    Chunks without speakable content (see ``contains_words``) are dropped.

    Args:
        text: input text.
        max_chars: UTF-8 byte budget per chunk.
        cut_short_first: if True, the first chunk is split once more at its
            first sentence boundary, so the first batch is short (presumably to
            cut time-to-first-audio when streaming — TODO confirm).

    Returns:
        List of stripped chunk strings; empty when *text* has nothing to speak.
    """
    sentences = [s.strip() for s in _SENTENCE_BOUNDARY.split(text) if s.strip()]

    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += _with_separator(sentence)
        else:
            if current_chunk and contains_words(current_chunk):
                chunks.append(current_chunk.strip())
            current_chunk = _with_separator(sentence)

    if current_chunk and contains_words(current_chunk):
        chunks.append(current_chunk.strip())

    # Nothing to shorten when not requested or when there are no chunks
    # (indexing chunks[0] would raise IndexError).
    if not cut_short_first or not chunks:
        return chunks

    # Split only at the first boundary: maxsplit=1 keeps the remainder's
    # original spacing, whereas re-joining every piece with "" would glue
    # whitespace-separated sentences together ("you?Fine.").
    head, *tail = _SENTENCE_BOUNDARY.split(chunks[0], maxsplit=1)
    first = head.strip()
    rest = tail[0].strip() if tail else ""
    first_chunk = [first, rest] if rest else [first]
    return first_chunk + chunks[1:]
|
|
|
|
|
|
|
|
|
|
|
|
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample a waveform to *target_sr* and convert float audio to int16 PCM.

    Resampling uses ``scipy.signal.resample`` (FFT based) along the first axis,
    so *audio* is assumed to be 1-D (TODO confirm for all callers). Floating
    input of any precision (float16/32/64) is clipped to [-1, 1], scaled by
    32767 and truncated to int16. Non-floating input is returned as scipy
    produces it, without conversion.

    Args:
        audio: input samples.
        ori_sr: sample rate of *audio*.
        target_sr: desired output sample rate.

    Returns:
        The resampled array; an empty int16 array when the resampled length
        would be zero.
    """
    number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
    if number_of_samples <= 0:
        # scipy's FFT rejects empty input and zero output lengths.
        return np.zeros(0, dtype=np.int16)

    audio_resampled = signal.resample(audio, number_of_samples)
    # Check "any float", not just float32: float64/float16 chunks would
    # otherwise be streamed as raw float bytes instead of int16 PCM.
    if np.issubdtype(audio.dtype, np.floating):
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    return audio_resampled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate(gen_text, ref_audio_key, fast_infer=0):
    """Stream synthesized speech for *gen_text* in a voice from ``voice_dict``.

    Args:
        gen_text: text to synthesize.
        ref_audio_key: key returned by /register_voice. A missing key raises
            KeyError.
        fast_infer: 0 uses 16 NFE steps. Values >= 1 use 7 steps, and values
            > 0 also ask split_text for a short first batch. Level 2 currently
            behaves like level 1.

    Yields:
        For each chunk streamed by ``infer_batch_process``, the bytes of
        ``audio_postprocess(chunk, sr, TARGET_SR)``.
    """
    voice = voice_dict[ref_audio_key]
    ref_audio_bytes = voice['ref_audio']
    ref_text = voice['ref_text']

    # Fewer sampling steps (NFE) — presumably trading quality for speed.
    nfe_step = 7 if fast_infer >= 1 else 16

    ref_wave, ref_sr = torchaudio.load(io.BytesIO(ref_audio_bytes))
    ref_seconds = ref_wave.shape[-1] / ref_sr
    # Per-batch byte budget derived from the reference's text-bytes-per-second
    # rate and (22 - reference seconds); presumably targets ~22 s in total.
    max_chars = int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds))
    batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))

    stream = infer_batch_process(
        (ref_wave, ref_sr),
        ref_text,
        batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=nfe_step,
        speed=speed,
    )
    for chunk, chunk_sr in stream:
        yield audio_postprocess(chunk, chunk_sr, TARGET_SR).tobytes()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_with_audio(gen_text, ref_audio, ref_text, fast_infer=0):
    """Stream synthesized speech for *gen_text* with an explicit reference voice.

    Same pipeline as ``generate``, with three differences: the reference audio
    bytes and transcript are passed in directly, the per-batch budget is capped
    at 135 UTF-8 bytes, and text without speakable content yields a single
    chunk of N_ZEROS int16 zero samples.

    Args:
        gen_text: text to synthesize.
        ref_audio: encoded reference audio bytes (decoded with torchaudio).
        ref_text: transcript of *ref_audio*.
        fast_infer: 0 uses 16 NFE steps. Values >= 1 use 7 steps, and values
            > 0 also ask split_text for a short first batch.

    Yields:
        Bytes of ``audio_postprocess(chunk, sr, TARGET_SR)`` for each streamed
        chunk.
    """
    if not contains_words(gen_text):
        yield np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return

    nfe_step = 7 if fast_infer >= 1 else 16

    ref_wave, ref_sr = torchaudio.load(io.BytesIO(ref_audio))
    ref_seconds = ref_wave.shape[-1] / ref_sr
    budget = int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds))
    max_chars = min(budget, 135)
    batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))

    stream = infer_batch_process(
        (ref_wave, ref_sr),
        ref_text,
        batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=nfe_step,
        speed=speed,
    )
    for chunk, chunk_sr in stream:
        yield audio_postprocess(chunk, chunk_sr, TARGET_SR).tobytes()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/synthesize")
async def synthesize(request: Request):
    """Stream speech for ``text`` in a voice registered via /register_voice.

    Expects a JSON object body:
        text (str): text to speak.
        audio_key (str): key returned by /register_voice.
        fast_infer (bool | int, optional): ``true`` maps to 2 and ``false`` to
            0; other values go through ``int()``. Defaults to 0.

    Returns:
        A 200 response holding N_ZEROS zero int16 samples when *text* has
        nothing speakable; otherwise a StreamingResponse over ``generate``.
        Both are headerless sample bytes despite the 'audio/wav' media type.

    Raises:
        HTTPException(400): malformed JSON, missing or ill-typed fields, or an
            unknown audio_key.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from None

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    text = data.get('text')
    audio_key = data.get('audio_key')
    if not isinstance(text, str) or not isinstance(audio_key, str):
        raise HTTPException(status_code=400, detail="'text' and 'audio_key' must be strings")

    fast_infer = data.get('fast_infer', 0)
    # Test for a real bool: `fast_infer == True` also matches the integer 1
    # (and 1.0), silently promoting it to 2.
    if isinstance(fast_infer, bool):
        fast_infer = 2 if fast_infer else 0
    else:
        try:
            fast_infer = int(fast_infer)
        except (TypeError, ValueError):
            raise HTTPException(status_code=400, detail="'fast_infer' must be a boolean or an integer") from None

    if not contains_words(text):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')

    if audio_key not in voice_dict:
        raise HTTPException(status_code=400, detail="Invalid audio key")

    return StreamingResponse(generate(text, audio_key, fast_infer), media_type="audio/wav")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Clark-notation prefix of the reserved ``xml:`` namespace (as in ``xml:lang``).
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"


@app.post("/tts")
def predict(ssml: str = Body(...), fast_infer: Union[bool, int] = 0):
    """Synthesize the text of an SSML ``<voice>`` element with the default voice.

    The first ``voice`` element below the root, in any XML namespace, supplies
    the text. All text nested inside it is concatenated and stripped.
    Synthesis uses ``std_ref_audio`` / ``std_ref_text``.

    Args:
        ssml: request body holding the SSML document.
        fast_infer: speed/quality level, passed on as ``int(fast_infer)``.

    Returns:
        A StreamingResponse over ``generate_with_audio`` on success. A 400
        JSONResponse when the body is not well-formed XML or has no ``voice``
        element.
    """
    # NOTE(review): this parses untrusted XML with the stdlib parser. Python's
    # docs list expat >= 2.4.1 as protected against "billion laughs"; with an
    # older expat, consider a hardened parser such as defusedxml.
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    # "{*}voice" matches <voice> with or without a namespace, so standard SSML
    # that declares xmlns="http://www.w3.org/2001/10/synthesis" is accepted too.
    voice_element = root.find(".//{*}voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # itertext() also picks up text nested in child elements (e.g. <prosody>)
    # and yields nothing for an empty element, where `.text` would be None.
    transcription = "".join(voice_element.itertext()).strip()

    return StreamingResponse(
        generate_with_audio(transcription, std_ref_audio, std_ref_text, int(fast_infer)),
        media_type="audio/wav"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-08-20 16:10:23 +08:00
|
|
|
|
@app.get("/health_check")
async def health_check():
    """Deep health probe: run a tiny matmul on the CUDA device and verify it.

    Returns:
        ``{"status": "ok"}`` when the product of a 10x20 and a 20x10 ones
        matrix sums to 10 * 20 * 10.

    Raises:
        HTTPException(503): if the CUDA computation fails or gives a wrong
            result. The cause is logged.
    """
    try:
        lhs = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        rhs = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        # .item() copies the sum to the host, so CUDA errors surface here.
        total = torch.matmul(lhs, rhs).sum().item()
    except Exception as e:
        logger.exception("health_check failed")
        raise HTTPException(status_code=503) from e

    if total != 10 * 20 * 10:
        logger.error("health_check failed: unexpected matmul sum %s", total)
        raise HTTPException(status_code=503)
    return {"status": "ok"}
|
2025-08-12 14:15:41 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Pass the app object rather than the import string "f5_server:app". The
    # string form re-imports this file under a second module name, re-running
    # all top-level setup, and breaks if the file is not named f5_server.py.
    # PORT overrides the listening port (default 80).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "80")))
|