Files
enginex-bi_series-tts/bi_v100-f5-tts/f5_server.py
2025-08-14 10:02:15 +08:00

134 lines
4.2 KiB
Python

import torch

# Force PyTorch's scaled-dot-product attention onto the math backend by
# disabling the fused flash / memory-efficient kernels.
# NOTE(review): presumably a workaround for a kernel incompatibility or
# precision issue on the deployment GPU — confirm before re-enabling.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
# Earlier experiment: forcing Conv1d to run in float32 under autocast.
# Kept for reference; currently disabled.
# def custom_conv1d_forward(self, input: Tensor, debug=False) -> Tensor:
#     with torch.amp.autocast(input.device.type, dtype=torch.float):
#         return self._conv_forward(input, self.weight, self.bias)
# torch.nn.Conv1d.forward = custom_conv1d_forward
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement for ``nn.ConvTranspose1d.forward``.

    Runs the transposed convolution under a float16 CUDA autocast region and
    casts the result back to float32 before returning. Mirrors the stock
    implementation otherwise (zeros padding only, same output-padding logic).
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
    assert isinstance(self.padding, tuple)
    # _output_padding wants a List here: TorchScript does not accept
    # Sequence[T] or Tuple[T, ...] in its signature.
    spatial_dims = 1
    extra_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        spatial_dims, self.dilation)  # type: ignore[arg-type]
    with torch.amp.autocast('cuda', dtype=torch.float16):
        result = F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            extra_padding, self.groups, self.dilation)
        return result.float()


# Patch the class so every ConvTranspose1d in loaded models uses this path.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
chunk_text,
infer_batch_process,
)
from omegaconf import OmegaConf
from hydra.utils import get_class
import torchaudio
import io
from fastapi import FastAPI
from fastapi import UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import os
import logging
# Root logging configuration: timestamped records; level taken from the
# LOGLEVEL environment variable (default INFO).
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Fix: use the module name, not __file__ — getLogger(__file__) produces an
# absolute-path logger name that overflows the %(name)-12s column and defeats
# hierarchical logger configuration.
logger = logging.getLogger(__name__)

# Model locations, overridable via environment.
model_dir = os.getenv('MODEL_DIR', '/models/SWivid/F5-TTS')
vocoder_dir = os.getenv('VOCODER_DIR', '/models/charactr/vocos-mel-24khz')

# Populated once by init() during application startup.
ema_model = None
vocoder = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def init():
    """Load the vocoder and the F5-TTS EMA model into module globals.

    Called once at application startup (from the FastAPI lifespan hook).
    Checkpoint/vocoder locations come from the MODEL_DIR / VOCODER_DIR
    environment variables; the model config YAML is now overridable via
    MODEL_CFG for consistency (default keeps the original hardcoded path).
    """
    global ema_model, vocoder

    # Vocoder: vocos, loaded from a local checkpoint directory.
    vocoder_name = 'vocos'
    vocoder = load_vocoder(
        vocoder_name=vocoder_name, is_local=True, local_path=vocoder_dir, device=device
    )

    # TTS model: resolve the backbone class named in the hydra-style config.
    model_cfg = OmegaConf.load(
        os.getenv('MODEL_CFG', '/workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')
    )
    model_cls = get_class(f'f5_tts.model.{model_cfg.model.backbone}')
    model_arc = model_cfg.model.arch
    ckpt_file = os.path.join(model_dir, 'F5TTS_v1_Base/model_1250000.safetensors')
    vocab_file = os.path.join(model_dir, 'F5TTS_v1_Base/vocab.txt')
    ema_model = load_model(
        model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name,
        vocab_file=vocab_file, device=device
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models before serving; nothing to release
    on shutdown."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
def tts_generate(gen_text, ref_audio, ref_text):
    """Stream synthesized speech for ``gen_text``, voice-cloned from a
    reference clip.

    Args:
        gen_text: text to synthesize.
        ref_audio: raw bytes of the reference audio file (any format
            torchaudio can decode).
        ref_text: transcript of the reference audio.

    Yields:
        bytes: raw generated audio samples per chunk (24 kHz PCM16, per the
        trailing note — TODO confirm against infer_batch_process).
    """
    global ema_model, vocoder
    audio, sr = torchaudio.load(io.BytesIO(ref_audio))
    # Per-batch text budget: scale the reference transcript's byte length by
    # the remaining time budget (22 s minus the reference duration), capped
    # at 135 chars. NOTE(review): a reference clip >= 22 s makes this
    # non-positive — confirm chunk_text tolerates that.
    max_chars = min(int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr)), 135)
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    for gen_audio, gen_sr in infer_batch_process(
        (audio, sr),
        ref_text,
        gen_text_batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        # Deliberately huge chunk_size: effectively emits each batch's audio
        # in a single piece rather than many small stream chunks.
        chunk_size=int(24e6),
        # nfe_step=16,
    ):
        yield gen_audio.tobytes()
    # return 24kHz pcm16
@app.post("/generate")
async def generate(
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(...),
    text: str = Form(...)
):
    """Synthesize ``text`` in the voice of the uploaded reference clip and
    stream the generated audio back to the client."""
    reference_bytes = await ref_audio.read()
    audio_stream = tts_generate(text, ref_audio=reference_bytes, ref_text=ref_text)
    # NOTE(review): the stream is raw PCM frames, not a WAV container —
    # "audio/wav" may mislead strict clients; confirm what consumers expect.
    return StreamingResponse(audio_stream, media_type="audio/wav")
@app.get("/ready")
@app.get("/health")
async def ready():
    """Health/readiness probe: reports ok whenever the server is serving."""
    return JSONResponse(content={"status": "ok"}, status_code=200)
if __name__ == "__main__":
    # Bind on all interfaces. Port is now overridable via the PORT environment
    # variable (consistent with MODEL_DIR/VOCODER_DIR); default 80 preserves
    # the original hardcoded behavior.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "80")))