"""FastAPI streaming TTS server wrapping the F5-TTS inference pipeline.

Endpoints:
  POST /generate       - stream synthesized audio for ``text``, voice-cloned
                         from an uploaded reference clip and its transcript.
  GET  /ready, /health - liveness/readiness probes.

Deployment knobs (environment variables):
  MODEL_DIR   - F5-TTS checkpoint/vocab root (default /models/SWivid/F5-TTS)
  VOCODER_DIR - local Vocos vocoder path   (default /models/charactr/vocos-mel-24khz)
  MODEL_CFG   - model architecture YAML    (default /workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml)
  PORT        - listen port when run directly (default 80)
  LOGLEVEL    - root logging level (default INFO)
"""
import io
import logging
import os
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
import torch.nn.functional as F
import torchaudio
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from hydra.utils import get_class
from omegaconf import OmegaConf
from torch import Tensor

from f5_tts.infer.utils_infer import (
    chunk_text,
    infer_batch_process,
    load_model,
    load_vocoder,
)

# Force the plain math scaled-dot-product-attention backend; flash and
# mem-efficient kernels are explicitly disabled.
# NOTE(review): presumably needed for numerical stability or hardware
# compatibility on the target GPUs -- confirm this is still required.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)


def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement for ``ConvTranspose1d.forward`` that computes the transposed
    convolution under fp16 CUDA autocast and returns a float32 tensor.

    Mirrors the upstream implementation (only ``'zeros'`` padding mode is
    supported), adding the ``torch.amp.autocast('cuda', float16)`` context
    around the functional call and a final ``.float()`` upcast.

    Raises:
        ValueError: if ``self.padding_mode`` is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
    assert isinstance(self.padding, tuple)
    # One cannot replace List by Tuple or Sequence in "_output_padding" because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
    num_spatial_dims = 1
    output_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        num_spatial_dims, self.dilation)  # type: ignore[arg-type]
    with torch.amp.autocast('cuda', dtype=torch.float16):
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation).float()


# Monkey-patch all ConvTranspose1d instances (class-level patch, so it applies
# to layers created before or after this line) with the fp16-autocast forward.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Was getLogger(__file__): __name__ is the logging convention and avoids a
# filesystem path leaking into log records.
logger = logging.getLogger(__name__)

# Deployment paths, overridable via environment.
model_dir = os.getenv('MODEL_DIR', '/models/SWivid/F5-TTS')
vocoder_dir = os.getenv('VOCODER_DIR', '/models/charactr/vocos-mel-24khz')
model_cfg_path = os.getenv(
    'MODEL_CFG', '/workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')

# Populated once by init() during application startup.
ema_model = None
vocoder = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def init():
    """Load the vocoder and the F5-TTS EMA model into module globals."""
    global ema_model, vocoder
    # Vocoder: local Vocos checkpoint (24 kHz mel).
    vocoder_name = 'vocos'
    vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=True,
                           local_path=vocoder_dir, device=device)
    # TTS model: backbone class + architecture come from the OmegaConf YAML,
    # weights and vocab from MODEL_DIR.
    model_cfg = OmegaConf.load(model_cfg_path)
    model_cls = get_class(f'f5_tts.model.{model_cfg.model.backbone}')
    model_arc = model_cfg.model.arch
    ckpt_file = os.path.join(model_dir, 'F5TTS_v1_Base/model_1250000.safetensors')
    vocab_file = os.path.join(model_dir, 'F5TTS_v1_Base/vocab.txt')
    ema_model = load_model(
        model_cls, model_arc, ckpt_file,
        mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device,
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load the models before serving; nothing to tear down."""
    init()
    yield


app = FastAPI(lifespan=lifespan)


def tts_generate(gen_text, ref_audio, ref_text):
    """Yield synthesized audio chunks as raw 24 kHz PCM16 bytes.

    Args:
        gen_text: text to synthesize.
        ref_audio: raw bytes of the reference clip (any format torchaudio loads).
        ref_text: transcript of the reference clip.
    """
    audio, sr = torchaudio.load(io.BytesIO(ref_audio))
    # Chunk budget: chars-per-second of the reference transcript times the
    # remaining seconds of a 22 s window, capped at 135 chars per batch.
    ref_seconds = audio.shape[-1] / sr
    max_chars = min(
        int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds)),
        135,
    )
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    for gen_audio, gen_sr in infer_batch_process(
        (audio, sr),
        ref_text,
        gen_text_batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        # NOTE(review): 24e6 samples per chunk effectively yields one chunk
        # per batch -- confirm this is intentional.
        chunk_size=int(24e6),
        # nfe_step=16,
    ):
        yield gen_audio.tobytes()  # headerless 24 kHz pcm16


@app.post("/generate")
async def generate(
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(...),
    text: str = Form(...)
):
    """Stream synthesized speech for ``text``, cloned from the uploaded clip."""
    audio_bytes = await ref_audio.read()
    # NOTE(review): the body is headerless PCM16, not a RIFF/WAV container,
    # despite the advertised media type -- confirm what clients expect before
    # changing either side.
    return StreamingResponse(
        tts_generate(text, ref_audio=audio_bytes, ref_text=ref_text),
        media_type="audio/wav"
    )


@app.get("/ready")
@app.get("/health")
async def ready():
    """Liveness/readiness probe for orchestrators."""
    return JSONResponse(status_code=200, content={"status": "ok"})


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "80")))