Merge code repos for F5, GPT, and Kokoro.
This commit is contained in:
133
f5_server.py
Normal file
133
f5_server.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import torch

# Restrict PyTorch scaled-dot-product attention to the plain math kernel.
# Flash and memory-efficient SDP are explicitly disabled — presumably they
# misbehaved with this model / GPU combination; TODO confirm before re-enabling.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F

# NOTE(review): earlier experiment that forced Conv1d to run under a float32
# autocast region; kept commented out for reference, currently inactive.
# def custom_conv1d_forward(self, input: Tensor, debug=False) -> Tensor:
#     with torch.amp.autocast(input.device.type, dtype=torch.float):
#         return self._conv_forward(input, self.weight, self.bias)

# torch.nn.Conv1d.forward = custom_conv1d_forward
|
||||
|
||||
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.ConvTranspose1d``.

    Runs the transposed convolution inside a float16 CUDA autocast region
    (so the matmul-heavy kernel can use half precision on GPU) and casts the
    result back to float32 before returning it.

    Args:
        input: input tensor for the transposed convolution.
        output_size: optional requested spatial output size, forwarded to
            ``_output_padding`` to resolve ambiguous output shapes.

    Raises:
        ValueError: if the module uses any padding mode other than zeros.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # ``_output_padding`` takes a plain int for the spatial rank because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]` here.
    spatial_rank = 1
    extra_padding = self._output_padding(
        input,
        output_size,
        self.stride,        # type: ignore[arg-type]
        self.padding,       # type: ignore[arg-type]
        self.kernel_size,   # type: ignore[arg-type]
        spatial_rank,
        self.dilation,      # type: ignore[arg-type]
    )
    with torch.amp.autocast('cuda', dtype=torch.float16):
        result = F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            extra_padding, self.groups, self.dilation,
        )
    # .float() is a plain cast, unaffected by autocast — output is float32.
    return result.float()


# Monkey-patch every ConvTranspose1d in the process (e.g. inside the vocoder)
# so it goes through the half-precision forward above.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
load_vocoder,
|
||||
load_model,
|
||||
chunk_text,
|
||||
infer_batch_process,
|
||||
)
|
||||
from omegaconf import OmegaConf
|
||||
from hydra.utils import get_class
|
||||
import torchaudio
|
||||
import io
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
import os
|
||||
|
||||
import logging
|
||||
# Configure root logging once at import; LOGLEVEL env var overrides INFO.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Use __name__ rather than __file__: getLogger(__file__) names the logger
# after the full filesystem path, which breaks the %(name)-12s column and
# the conventional dotted-logger hierarchy.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Checkpoint locations, overridable via environment for containerized deploys.
model_dir = os.getenv('MODEL_DIR', '/models/SWivid/F5-TTS')
vocoder_dir = os.getenv('VOCODER_DIR', '/models/charactr/vocos-mel-24khz')

# Populated lazily by init() from the FastAPI lifespan hook.
ema_model = None
vocoder = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
def init():
    """Load the vocoder and the F5-TTS model into the module-level globals.

    Invoked once from the FastAPI lifespan hook before traffic is served.
    Reads checkpoints from ``model_dir`` / ``vocoder_dir`` and places both
    models on ``device``.
    """
    global ema_model, vocoder
    # Load the vocoder from a local checkout.
    vocoder_name = 'vocos'
    vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=True, local_path=vocoder_dir, device=device)

    # Load the TTS model. MODEL_CFG overrides the config path so the server
    # is not tied to the /workspace checkout layout (consistent with the
    # MODEL_DIR / VOCODER_DIR env overrides above); the default is unchanged.
    model_cfg_path = os.getenv('MODEL_CFG', '/workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')
    model_cfg = OmegaConf.load(model_cfg_path)
    model_cls = get_class(f'f5_tts.model.{model_cfg.model.backbone}')
    model_arc = model_cfg.model.arch
    ckpt_file = os.path.join(model_dir, 'F5TTS_v1_Base/model_1250000.safetensors')
    vocab_file = os.path.join(model_dir, 'F5TTS_v1_Base/vocab.txt')
    ema_model = load_model(
        model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
    )
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models once at startup; nothing to release on shutdown."""
    init()
    yield


app = FastAPI(lifespan=lifespan)
|
||||
|
||||
def tts_generate(gen_text, ref_audio, ref_text):
    """Stream synthesized speech for ``gen_text`` as raw audio byte chunks.

    Args:
        gen_text: text to synthesize.
        ref_audio: raw bytes of the reference audio clip (voice to clone).
        ref_text: transcript of the reference clip.

    Yields:
        Byte chunks of generated audio from the streaming inference loop.
    """
    global ema_model, vocoder

    ref_wave, ref_sr = torchaudio.load(io.BytesIO(ref_audio))
    ref_seconds = ref_wave.shape[-1] / ref_sr
    # Budget characters per batch from the reference speech rate (UTF-8 bytes
    # of transcript per second of audio), hard-capped at 135 characters.
    # NOTE(review): assumes the reference clip is shorter than 22 s, otherwise
    # the budget goes non-positive — confirm upstream validation.
    max_chars = min(int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds)), 135)
    batches = chunk_text(gen_text, max_chars=max_chars)

    stream = infer_batch_process(
        (ref_wave, ref_sr),
        ref_text,
        batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        # nfe_step=16,
    )
    for chunk, _chunk_sr in stream:
        yield chunk.tobytes()
|
||||
|
||||
# return 24kHz pcm16
|
||||
# return 24kHz pcm16
@app.post("/generate")
async def generate(
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(...),
    text: str = Form(...)
):
    """Synthesize ``text`` in the voice of the uploaded reference clip.

    Streams the generated audio back as it is produced.
    """
    ref_bytes = await ref_audio.read()
    pcm_stream = tts_generate(text, ref_audio=ref_bytes, ref_text=ref_text)
    return StreamingResponse(pcm_stream, media_type="audio/wav")
|
||||
|
||||
|
||||
@app.get("/ready")
@app.get("/health")
async def ready():
    """Liveness/readiness probe: always reports OK once the app is serving."""
    return JSONResponse(content={"status": "ok"}, status_code=200)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Serve on all interfaces; PORT env var overrides the default HTTP port 80,
    # matching the env-override convention used for MODEL_DIR / VOCODER_DIR.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "80")))
|
||||
Reference in New Issue
Block a user