Files
2025-08-20 16:10:23 +08:00

342 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os

# Where the model files are mounted and which checkpoint file to load from
# that directory; both can be overridden through the environment.
model_dir = os.environ.get("MODEL_DIR", "/mounted_model")
model_name = os.environ.get("MODEL_NAME", "model.safetensors")

# Process-wide logging configuration: timestamped records, with the level
# taken from LOGLEVEL (INFO when unset).
logging.basicConfig(
    level=os.getenv("LOGLEVEL", "INFO"),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
)

# Module logger, named after this file's path.
logger = logging.getLogger(__file__)
# Optional hook: when the model directory contains a custom_patcher.py, copy it
# into the current working directory and import it here, i.e. before torch and
# the TTS stack are imported further down this file.
# NOTE(review): the import relies on the working directory being importable
# (on sys.path) -- confirm for the deployment entry point.
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
    import shutil

    shutil.copyfile(patcher_path, "custom_patcher.py")
    try:
        import custom_patcher  # noqa: F401 -- imported only for its side effects
    except ImportError:
        # Startup stays best-effort, but the failure is logged at ERROR level
        # together with the traceback so it is visible regardless of LOGLEVEL
        # and the actual cause (e.g. a missing dependency) is not lost.
        logger.exception("Failed to import custom_patcher. Ensure it is a valid Python module.")
    else:
        logger.info("Custom patcher has been applied.")
else:
    logger.info("No custom_patcher found.")
import torch
# Limit PyTorch to four CPU threads for intra-op parallelism.
torch.set_num_threads(4)
# Scaled-dot-product attention backend selection: the flash and
# memory-efficient kernels are switched off and only the math implementation
# is left enabled.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
# def custom_conv1d_forward(self, input: Tensor, debug=False) -> Tensor:
# with torch.amp.autocast(input.device.type, dtype=torch.float):
# return self._conv_forward(input, self.weight, self.bias)
# torch.nn.Conv1d.forward = custom_conv1d_forward
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.ConvTranspose1d``.

    The transposed convolution is executed inside a CUDA autocast region with
    float16 as the autocast dtype, and the result is cast back to float32
    before it is returned.

    Args:
        input: Tensor fed to the transposed convolution.
        output_size: Optional requested output size, forwarded to the
            module's ``_output_padding`` helper.

    Returns:
        The convolution result as a float32 tensor.

    Raises:
        ValueError: If the module's ``padding_mode`` is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
    assert isinstance(self.padding, tuple)

    # `_output_padding` is annotated with List because TorchScript supports
    # neither `Sequence[T]` nor `Tuple[T, ...]`; hence the type-ignore marks.
    spatial_dims = 1
    extra_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        spatial_dims, self.dilation,  # type: ignore[arg-type]
    )

    with torch.amp.autocast('cuda', dtype=torch.float16):
        result = F.conv_transpose1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            extra_padding,
            self.groups,
            self.dilation,
        )
        return result.float()


# Install the replacement on every ConvTranspose1d module in this process.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
infer_batch_process,
)
from omegaconf import OmegaConf
from hydra.utils import get_class
import torch
import re
import numpy as np
import soundfile as sf
import torchaudio
from scipy import signal
import io
import time
from fastapi import FastAPI, Request, Response, Body, HTTPException
from fastapi import UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import os
import hashlib
import xml.etree.ElementTree as ET
from typing import Union
# Environment-driven settings: local vocoder checkpoint directory and the
# speed factor handed to the inference call.
vocoder_dir = os.environ.get('VOCODER_DIR', '/workspace/charactr/vocos-mel-24khz')
speed = float(os.environ.get('SPEED', 1.0))

# Use the GPU when one is visible to torch, otherwise run on the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Sample rate the generated audio is resampled to before it is streamed out.
TARGET_SR = 16000
# Number of zero-valued int16 samples returned for text without any words.
N_ZEROS = 20

# Filled in by init(): the loaded TTS model and vocoder.
ema_model = None
vocoder = None
# Registered voices, keyed by audio key -> {'ref_audio': bytes, 'ref_text': str}.
voice_dict = {}

# Default reference voice used by the /tts endpoint. Its files are read by
# init(). An earlier variant looked them up under the model directory:
# std_ref_audio_file = os.path.join(model_dir, 'ref_audio.wav')
# std_ref_text_file = os.path.join(model_dir, 'ref_text.txt')
std_ref_audio_file = '/workspace/ref_audio.wav'
std_ref_text_file = '/workspace/ref_text.txt'
std_ref_audio = None
std_ref_text = None
def init():
    """Load the vocoder, the TTS model and the default reference voice.

    Populates the module globals ``vocoder``, ``ema_model``,
    ``std_ref_audio`` (raw file bytes) and ``std_ref_text`` (stripped text).
    """
    global ema_model, vocoder
    global std_ref_audio, std_ref_text

    logger.info(f'{device=}')

    # Vocoder, loaded from the local directory `vocoder_dir`.
    vocoder_name = 'vocos'
    vocoder = load_vocoder(
        vocoder_name=vocoder_name,
        is_local=True,
        local_path=vocoder_dir,
        device=device,
    )

    # TTS model: the backbone class and architecture come from the F5-TTS
    # base config; checkpoint and vocabulary come from the model directory.
    config = OmegaConf.load('/workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')
    backbone_cls = get_class(f'f5_tts.model.{config.model.backbone}')
    architecture = config.model.arch
    checkpoint_path = os.path.join(model_dir, model_name)
    vocab_path = os.path.join(model_dir, 'vocab.txt')
    ema_model = load_model(
        backbone_cls,
        architecture,
        checkpoint_path,
        mel_spec_type=vocoder_name,
        vocab_file=vocab_path,
        device=device,
    )

    # Default reference voice: audio kept as raw bytes, transcript stripped.
    with open(std_ref_audio_file, 'rb') as audio_fh:
        std_ref_audio = audio_fh.read()
    with open(std_ref_text_file, 'r', encoding='utf-8') as text_fh:
        std_ref_text = text_fh.read().strip()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: run ``init()`` once before serving starts.

    Nothing is done after the ``yield``, i.e. there is no shutdown work.
    """
    init()
    yield


# The ASGI application; models are loaded through the lifespan hook above.
app = FastAPI(lifespan=lifespan)
@app.get("/health")
@app.get("/ready")
async def ready():
    """Probe endpoint served on both /health and /ready; always answers 200."""
    payload = {"message": "success"}
    return JSONResponse(content=payload, status_code=200)
def encode_audio_key(audio_bytes: bytes) -> str:
    """Derive the 16-hex-character key under which a voice is stored.

    Only the first 16000 bytes of the audio are hashed (MD5), so two uploads
    that share that prefix map to the same key.
    """
    digest = hashlib.md5()
    digest.update(audio_bytes[:16000])
    hex_digest = digest.hexdigest()
    return hex_digest[:16]
@app.post("/register_voice")
async def register_voice(
    audio: UploadFile = File(...),
    text: str = Form(...)
):
    """Register a reference voice and return the key used to address it.

    The uploaded audio and its transcript are stored in ``voice_dict`` under
    the key computed by ``encode_audio_key``; a short warm-up synthesis is
    then run with the new voice before the key is returned.

    Args:
        audio: Uploaded reference audio file.
        text: Transcript of the reference audio.

    Returns:
        A 200 JSON response ``{"status": "success", "audio_key": <key>}``.
    """
    global voice_dict
    audio_bytes = await audio.read()
    audio_key = encode_audio_key(audio_bytes)

    # Ensure ref_text ends with a proper sentence-ending punctuation.
    # The previous check compared against the empty string, which every string
    # ends with, so the transcript was never normalized. A transcript already
    # ending in ". " or in the full-width stop "。" is left alone.
    if not text.endswith((". ", "。")):
        text += " " if text.endswith(".") else ". "

    voice_dict[audio_key] = {
        'ref_audio': audio_bytes,
        # Stored stripped, so the trailing space added above does not survive;
        # the terminating "." does.
        'ref_text': text.strip()
    }

    # warmup: drain one fast-inference synthesis so the first real request
    # with this voice does not pay the first-run cost.
    for _ in generate("流式语音合成,合成测试", audio_key, fast_infer=2):
        logger.info("Warming up")

    response = {
        "status": "success",
        "audio_key": audio_key
    }
    return JSONResponse(status_code=200, content=response)
# Punctuation characters that do not count as speakable content.
symbols = """,.!?;:()[]{}<>,。!?;:【】《》……'"“”_—"""


def contains_words(text):
    """Return True if *text* has at least one speakable character.

    A character is speakable when it is neither listed in ``symbols`` nor
    whitespace. Whitespace is excluded explicitly: otherwise blank input, or
    punctuation separated by spaces, would be reported as containing words
    and be sent on to synthesis.
    """
    return any(char not in symbols and not char.isspace() for char in text)
# Sentence boundary: whitespace following one of ;:,.!? or the empty position
# right after one of ;:,。!? (no whitespace required).
_SENTENCE_BOUNDARY = re.compile(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])")


def _with_separator(sentence):
    """Append a space when *sentence* ends in a single-byte (ASCII) character."""
    if len(sentence[-1].encode("utf-8")) == 1:
        return sentence + " "
    return sentence


def split_text(text, max_chars=135, cut_short_first=False):
    """Split *text* into chunks of whole sentences for batched synthesis.

    Sentences are packed greedily into a chunk while the combined UTF-8 byte
    length stays within ``max_chars``; a single sentence longer than the
    budget becomes a chunk of its own. Chunks for which ``contains_words`` is
    false are dropped.

    Args:
        text: Text to split.
        max_chars: Budget per chunk, measured in UTF-8 bytes.
        cut_short_first: When true, the first chunk is additionally split
            after its first sentence, so the first chunk is as short as possible.

    Returns:
        List of chunk strings; empty when *text* yields no usable chunk.
    """
    sentences = [piece.strip() for piece in _SENTENCE_BOUNDARY.split(text) if piece.strip()]

    chunks = []
    current_chunk = ""
    for sentence in sentences:
        used = len(current_chunk.encode("utf-8"))
        if used + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += _with_separator(sentence)
            continue
        # Budget exceeded: flush what has been collected and start over.
        if current_chunk and contains_words(current_chunk):
            chunks.append(current_chunk.strip())
        current_chunk = _with_separator(sentence)
    if current_chunk and contains_words(current_chunk):
        chunks.append(current_chunk.strip())

    # Nothing to shorten when there are no chunks (previously an IndexError).
    if not (cut_short_first and chunks):
        return chunks

    # Split only at the first boundary so the remainder keeps its original
    # spacing; re-joining all pieces with "" glued sentences together
    # ("world.Foo").
    parts = _SENTENCE_BOUNDARY.split(chunks[0], maxsplit=1)
    first = parts[0].strip()
    rest = parts[1].strip() if len(parts) > 1 else ""
    first_chunk = [first, rest] if rest else [first]
    return first_chunk + chunks[1:]
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from ``ori_sr`` to ``target_sr``.

    float32 input is additionally clipped to [-1.0, 1.0] and converted to
    int16 PCM; input of any other dtype is returned as resampled, without
    conversion.
    """
    target_len = int(len(audio) * float(target_sr) / ori_sr)
    resampled = signal.resample(audio, target_len)
    if audio.dtype != np.float32:
        return resampled
    clipped = np.clip(resampled, -1.0, 1.0)
    return (clipped * 32767).astype(np.int16)
def generate(gen_text, ref_audio_key, fast_infer=0):
    """Stream synthesized speech for *gen_text* in a registered voice.

    Args:
        gen_text: Text to synthesize.
        ref_audio_key: Key of the voice in ``voice_dict``.
        fast_infer: ``>= 1`` lowers the number of function evaluations from
            16 to 7; any value ``> 0`` also makes the first text chunk short.

    Yields:
        Raw audio bytes, one item per chunk produced by the inference call,
        after ``audio_postprocess`` to ``TARGET_SR``.
    """
    voice = voice_dict[ref_audio_key]
    ref_bytes = voice['ref_audio']
    ref_transcript = voice['ref_text']

    nfe_step = 7 if fast_infer >= 1 else 16

    audio, sr = torchaudio.load(io.BytesIO(ref_bytes))

    # Text budget per chunk: UTF-8 bytes of reference text per second of
    # reference audio, multiplied by (22 - reference duration in seconds).
    ref_seconds = audio.shape[-1] / sr
    max_chars = int(len(ref_transcript.encode("utf-8")) / ref_seconds * (22 - ref_seconds))
    gen_text_batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))

    stream = infer_batch_process(
        (audio, sr),
        ref_transcript,
        gen_text_batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=nfe_step,
        speed=speed,
    )
    for gen_audio, gen_sr in stream:
        pcm = audio_postprocess(gen_audio, gen_sr, TARGET_SR)
        yield pcm.tobytes()
def generate_with_audio(gen_text, ref_audio, ref_text, fast_infer=0):
    """Stream synthesized speech for *gen_text* using an explicit reference.

    Args:
        gen_text: Text to synthesize.
        ref_audio: Reference audio as encoded file bytes.
        ref_text: Transcript of the reference audio.
        fast_infer: ``>= 1`` lowers the number of function evaluations from
            16 to 7; any value ``> 0`` also makes the first text chunk short.

    Yields:
        Raw audio bytes per generated chunk. When *gen_text* has no words, a
        single buffer of ``N_ZEROS`` zero int16 samples is yielded instead.
    """
    if not contains_words(gen_text):
        yield np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return

    nfe_step = 7 if fast_infer >= 1 else 16

    audio, sr = torchaudio.load(io.BytesIO(ref_audio))

    # Same budget formula as generate(), additionally capped at 135.
    ref_seconds = audio.shape[-1] / sr
    budget = int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds))
    max_chars = min(budget, 135)
    gen_text_batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))

    stream = infer_batch_process(
        (audio, sr),
        ref_text,
        gen_text_batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=nfe_step,
        speed=speed,
    )
    for gen_audio, gen_sr in stream:
        pcm = audio_postprocess(gen_audio, gen_sr, TARGET_SR)
        yield pcm.tobytes()
@app.post("/synthesize")
async def synthesize(request: Request):
    """Synthesize speech for a JSON request using a registered voice.

    Expected JSON body: ``{"text": str, "audio_key": str, "fast_infer":
    bool | int (optional, default 0)}``. ``fast_infer: true`` selects level 2.

    Returns:
        A streaming response of raw audio bytes, or a short all-zero buffer
        when the text contains no words.

    Raises:
        HTTPException: 400 when the body is not valid JSON, a required field
            is missing or mistyped, ``fast_infer`` is not a bool/integer, or
            ``audio_key`` is not registered.
    """
    try:
        data = await request.json()
        text = data['text']
        audio_key = data['audio_key']
        fast_infer = data.get('fast_infer', 0)
    except (ValueError, KeyError, TypeError, AttributeError) as exc:
        # Malformed JSON, a non-object body, or a missing field is a client
        # error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="Request body must be a JSON object with 'text' and 'audio_key'",
        ) from exc
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")

    # Identity check: `== True` would also match the integer 1 and silently
    # promote it to level 2.
    if fast_infer is True:
        fast_infer = 2
    else:
        try:
            fast_infer = int(fast_infer)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid fast_infer value") from exc

    # logger.info(f"Synthesizing text: {text}, audio_key: {audio_key}, fast_infer: {fast_infer}")
    # Word-less text is answered with silence before the key is validated.
    if not contains_words(text):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')

    if audio_key not in voice_dict:
        raise HTTPException(status_code=400, detail="Invalid audio key")
    return StreamingResponse(generate(text, audio_key, fast_infer), media_type="audio/wav")
# Clark-notation prefix of the reserved XML namespace (e.g. for xml:lang).
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"


@app.post("/tts")
def predict(ssml: str = Body(...), fast_infer: Union[bool, int] = 0):
    """Synthesize the text of the first ``<voice>`` element of an SSML body.

    The default reference voice loaded by ``init()`` is used. Only the text
    directly inside ``<voice>`` (before any child element) is read; the
    element's attributes, including ``xml:lang``, are not used.

    Args:
        ssml: SSML document passed as the request body.
        fast_infer: Fast-inference level; booleans are converted with ``int``.

    Returns:
        A streaming response of raw audio bytes, or a 400 JSON response when
        the document cannot be parsed or has no ``<voice>`` descendant.
    """
    # NOTE(review): the body is untrusted and xml.etree.ElementTree is not
    # hardened against maliciously constructed XML -- consider a hardened
    # parser if this endpoint is exposed publicly.
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    voice_element = root.find(".//voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # `.text` is None for an empty <voice/> or one that starts with a child
    # element; treat that as empty text (answered with silence downstream)
    # instead of failing with AttributeError.
    transcription = (voice_element.text or "").strip()

    fast_infer = int(fast_infer)
    return StreamingResponse(
        generate_with_audio(transcription, std_ref_audio, std_ref_text, fast_infer),
        media_type="audio/wav"
    )
@app.get("/health_check")
async def health_check():
    """GPU health probe: run a small matmul on CUDA and verify the result.

    Returns:
        ``{"status": "ok"}`` when the product of a 10x20 and a 20x10 matrix
        of ones sums to 2000.

    Raises:
        HTTPException: 503 when the computation raises or gives another sum.
    """
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        total = torch.matmul(a, b).sum().item()
    except Exception as exc:
        # Record the cause (with traceback) instead of discarding it.
        logger.exception('health_check failed')
        raise HTTPException(status_code=503) from exc

    if total != 10 * 20 * 10:
        logger.error('health_check failed')
        raise HTTPException(status_code=503)
    return {"status": "ok"}
# Script entry point: serve the `app` object of the `f5_server` module with
# uvicorn, listening on all interfaces on port 80.
if __name__ == "__main__":
    uvicorn.run("f5_server:app", host="0.0.0.0", port=80)