Files
enginex-mlu370-tts/mlu_370-f5-tts/f5_server.py
2025-09-10 10:47:02 +08:00

317 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# Model checkpoint location, overridable through the environment:
#   MODEL_DIR  - directory holding the checkpoint (and vocab.txt, see init())
#   MODEL_NAME - checkpoint file name inside MODEL_DIR
model_dir, model_name = (
    os.environ.get(var, fallback)
    for var, fallback in (
        ("MODEL_DIR", "/mounted_model"),
        ("MODEL_NAME", "model.safetensors"),
    )
)
import logging
# Root logging configuration. LOGLEVEL is upper-cased because logging only
# recognises upper-case level names: LOGLEVEL=debug would otherwise make
# basicConfig raise ValueError("Unknown level: 'debug'") at import time.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
# Module logger named after the module (conventional), not the file path.
logger = logging.getLogger(__name__)
import torch
# Intra-op CPU thread count for torch (default 4); override with
# TORCH_NUM_THREADS to match the container's CPU allocation.
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
infer_batch_process,
)
from omegaconf import OmegaConf
from hydra.utils import get_class
import torch
import re
import numpy as np
import soundfile as sf
import torchaudio
from scipy import signal
import io
import time
from fastapi import FastAPI, Request, Response, Body, HTTPException
from fastapi import UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import os
import hashlib
import xml.etree.ElementTree as ET
from typing import Union
# Local directory of the vocos vocoder loaded by init().
vocoder_dir = os.getenv('VOCODER_DIR', '/app/charactr/vocos-mel-24khz')
# Speech-rate factor passed to infer_batch_process on every request.
speed = float(os.getenv('SPEED', 1.0))

# Populated by init() at startup.
ema_model = None
vocoder = None
# audio_key -> {'ref_audio': bytes, 'ref_text': str}; filled by /register_voice.
voice_dict = {}

# Inference device. Defaults to 'mlu'; set DEVICE (e.g. DEVICE=cuda or
# DEVICE=cpu) to run elsewhere without editing the file.
device = os.getenv('DEVICE', 'mlu')

TARGET_SR = 16000  # sample rate of the audio streamed back to clients
N_ZEROS = 20  # number of int16 zero samples returned when there is nothing to speak

# Disabled experiment: AMP autocast on MLU (bf16 by default; AMP_DTYPE=fp16
# to force fp16). Kept for reference only; nothing below uses it.
# AMP_DTYPE_ENV = os.getenv("AMP_DTYPE", "bf16").lower()
# def _amp_dtype_for_mlu():
#     return torch.float16 if AMP_DTYPE_ENV in ("fp16", "float16", "16") else torch.bfloat16
# def mlu_autocast():
#     # torch.autocast supports device_type="mlu"
#     return torch.autocast(device_type="mlu", dtype=_amp_dtype_for_mlu())

# Default reference voice used by /tts; loaded into std_ref_audio /
# std_ref_text by init(). (Earlier these lived under model_dir:
# os.path.join(model_dir, 'ref_audio.wav') / 'ref_text.txt'.)
std_ref_audio_file = '/app/ref_audio.wav'
std_ref_text_file = '/app/ref_text.txt'
std_ref_audio = None
std_ref_text = None
def init():
    """Load the models and the default reference voice into module globals.

    Sets:
        vocoder: the 'vocos' vocoder loaded from ``vocoder_dir`` on ``device``.
        ema_model: the F5-TTS model built from the YAML config (path from the
            MODEL_CFG environment variable, defaulting to the bundled
            F5TTS_v1_Base.yaml) with weights ``model_dir/model_name`` and
            vocabulary ``model_dir/vocab.txt``, on ``device``.
        std_ref_audio: raw bytes of ``std_ref_audio_file``.
        std_ref_text: stripped UTF-8 contents of ``std_ref_text_file``.

    Called from the FastAPI lifespan hook.

    Raises:
        OSError: if either reference file cannot be opened; load errors from
            load_vocoder / load_model propagate unchanged.
    """
    global ema_model, vocoder
    global std_ref_audio, std_ref_text
    logger.info(f'{device=}')

    # Vocoder
    vocoder_name = 'vocos'
    vocoder = load_vocoder(
        vocoder_name=vocoder_name, is_local=True, local_path=vocoder_dir, device=device
    )

    # TTS model
    cfg_path = os.getenv('MODEL_CFG', '/app/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')
    model_cfg = OmegaConf.load(cfg_path)
    model_cls = get_class(f'f5_tts.model.{model_cfg.model.backbone}')
    model_arc = model_cfg.model.arch
    ckpt_file = os.path.join(model_dir, model_name)
    vocab_file = os.path.join(model_dir, 'vocab.txt')
    ema_model = load_model(
        model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
    )
    try:
        ema_model.eval()
    except Exception:
        # Best effort, as before, but leave a trace instead of hiding the error.
        logger.warning("ema_model.eval() failed; continuing without it", exc_info=True)

    # Default reference voice used by /tts
    with open(std_ref_audio_file, 'rb') as f:
        std_ref_audio = f.read()
    with open(std_ref_text_file, 'r', encoding='utf-8') as f:
        std_ref_text = f.read().strip()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: run init() at startup; nothing to release at shutdown."""
    init()
    yield


# Application object; every route below is registered on it.
app = FastAPI(lifespan=lifespan)
@app.get("/health")
@app.get("/ready")
async def ready():
    """Health/readiness probe: unconditionally answers 200 {"message": "success"}."""
    payload = {"message": "success"}
    return JSONResponse(content=payload, status_code=200)
def encode_audio_key(audio_bytes: bytes) -> str:
    """Return a 16-hex-character key that identifies an uploaded audio file.

    The key is the leading 16 hex digits of the MD5 digest of the *entire*
    upload (a content fingerprint, not a security measure). Hashing the full
    payload, rather than a fixed-size prefix, keeps different clips that share
    their leading bytes (e.g. identical WAV headers followed by silence) from
    mapping to the same key and overwriting each other in ``voice_dict``.
    """
    return hashlib.md5(audio_bytes).hexdigest()[:16]
def _normalize_ref_text(text: str) -> str:
    """Strip *text* and ensure it ends with '.' or the ideographic full stop
    (U+3002), appending '.' otherwise."""
    text = text.strip()
    # U+3002 is written as an escape so it cannot be silently normalised away.
    if not text.endswith((".", "\u3002")):
        text += "."
    return text


@app.post("/register_voice")
async def register_voice(
    audio: UploadFile = File(...),
    text: str = Form(...)
):
    """Register a reference voice and return the key used to synthesize with it.

    Stores the uploaded audio bytes and its (punctuation-normalised)
    transcript in ``voice_dict`` under ``encode_audio_key(audio_bytes)``, then
    runs one short fast-path synthesis as a warm-up. Re-registering the same
    audio overwrites the stored transcript.

    Returns:
        JSON ``{"status": "success", "audio_key": <key>}``; pass the key as
        ``audio_key`` to /synthesize.
    """
    audio_bytes = await audio.read()
    audio_key = encode_audio_key(audio_bytes)
    # Ensure ref_text ends with sentence-ending punctuation.
    voice_dict[audio_key] = {
        'ref_audio': audio_bytes,
        'ref_text': _normalize_ref_text(text),
    }
    # Warm up with a short fixed sentence through the fast path.
    for _ in generate("流式语音合成,合成测试", audio_key, fast_infer=2):
        logger.info("Warming up")
    response = {
        "status": "success",
        "audio_key": audio_key,
    }
    return JSONResponse(status_code=200, content=response)
# Characters that do not count as "words" in contains_words(): ASCII plus
# CJK/full-width punctuation. Non-ASCII characters are written as escapes so
# they cannot be normalised to ASCII look-alikes by editors or web viewers.
symbols = (
    ",.!?;:()[]{}<>"
    "\uff0c\u3002\uff01\uff1f\uff1b\uff1a"  # full-width comma, ideographic full stop, ! ? ; :
    "\u3010\u3011\u300a\u300b\u2026"  # CJK brackets and the ellipsis
    "'\"\u201c\u201d_\u2014"  # quotes, curly quotes, underscore, em dash
)

# Sentence boundaries: ASCII ;:,.!? followed by whitespace (the whitespace is
# consumed), or directly after full-width ;:,。!? (zero-width, since CJK text
# has no spaces). ASCII punctuation with no following whitespace, as in
# "1,000" or "3.14", is deliberately not a boundary.
_SENTENCE_BOUNDARY = re.compile(
    r"(?<=[;:,.!?])\s+|(?<=[\uff1b\uff1a\uff0c\u3002\uff01\uff1f])"
)


def contains_words(text):
    """Return True if *text* has a character that is neither whitespace nor
    one of ``symbols`` (i.e. there is something to pronounce)."""
    return any(not ch.isspace() and ch not in symbols for ch in text)


def _with_separator(sentence):
    """Append a space if *sentence* ends in a single-byte (ASCII) character so
    joined Latin-script sentences keep a word gap; CJK is joined without one."""
    if sentence and len(sentence[-1].encode("utf-8")) == 1:
        return sentence + " "
    return sentence


def split_text(text, max_chars=135, cut_short_first=False):
    """Split *text* into synthesis chunks of at most ~*max_chars* UTF-8 bytes.

    Sentences are packed greedily into chunks; a single sentence longer than
    *max_chars* becomes its own chunk. Chunks without pronounceable
    characters (see contains_words) are dropped.

    Args:
        text: input text.
        max_chars: byte budget per chunk (compared on UTF-8 length).
        cut_short_first: if True, the first chunk is further cut at its first
            sentence boundary so the first piece is short.

    Returns:
        List of stripped chunk strings (empty if nothing is pronounceable).
    """
    sentences = [s.strip() for s in _SENTENCE_BOUNDARY.split(text)]
    sentences = [s for s in sentences if s]

    chunks = []
    current = ""
    for sentence in sentences:
        if len(current.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current += _with_separator(sentence)
        else:
            if contains_words(current):
                chunks.append(current.strip())
            current = _with_separator(sentence)
    if contains_words(current):
        chunks.append(current.strip())

    if cut_short_first and chunks:
        # maxsplit=1 keeps the remainder verbatim, including its inner spaces.
        head = [part.strip() for part in _SENTENCE_BOUNDARY.split(chunks[0], maxsplit=1)]
        chunks = [part for part in head if contains_words(part)] + chunks[1:]
    return chunks
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from *ori_sr* to *target_sr* and convert it to int16 PCM.

    Args:
        audio: waveform with samples along axis 0 (assumed mono/1-D). Floating
            input is expected in [-1, 1].
        ori_sr: sample rate of *audio*.
        target_sr: desired output sample rate.

    Returns:
        For any floating dtype (float16/32/64): clipped to [-1, 1] and scaled
        to int16. Non-float input is returned as the raw resampler output.
        An empty int16 array when the resampled length would be 0.
    """
    number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
    if number_of_samples == 0:
        # signal.resample cannot produce 0 samples (it raises); nothing to emit.
        return np.zeros(0, dtype=np.int16)
    audio_resampled = signal.resample(audio, number_of_samples)
    # Accept every float dtype, not only float32: callers stream .tobytes() as
    # int16, so e.g. fp16/fp64 model output must be converted as well.
    if np.issubdtype(audio.dtype, np.floating):
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        audio_resampled = (audio_resampled * 32767).astype(np.int16)
    return audio_resampled
def generate(gen_text, ref_audio_key, fast_infer=0):
    """Stream synthesized speech for *gen_text* in a registered voice.

    Args:
        gen_text: text to synthesize.
        ref_audio_key: key into ``voice_dict`` (as returned by /register_voice).
        fast_infer: >= 1 uses 7 NFE steps instead of 16; > 0 also passes
            ``cut_short_first=True`` to split_text (presumably to lower the
            latency of the first audio chunk).

    Yields:
        ``audio_postprocess(...).tobytes()`` chunks resampled to ``TARGET_SR``,
        or a single block of ``N_ZEROS`` int16 zeros when *gen_text* has
        nothing pronounceable.

    Raises:
        KeyError: if *ref_audio_key* is not registered (raised on first
            iteration, since this is a generator).
    """
    voice = voice_dict[ref_audio_key]
    ref_audio_ = voice['ref_audio']
    ref_text_ = voice['ref_text']

    # Same guard as generate_with_audio: with no pronounceable text there is
    # nothing for split_text / infer_batch_process to synthesize.
    if not contains_words(gen_text):
        yield np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return

    nfe_step = 7 if fast_infer >= 1 else 16
    audio, sr = torchaudio.load(io.BytesIO(ref_audio_))
    ref_duration = audio.shape[-1] / sr
    # Chunk budget = reference UTF-8 bytes per second * (22 - reference
    # seconds), i.e. about what the reference speaker says in the remaining
    # part of a 22 s window. Not capped at 135, unlike generate_with_audio.
    max_chars = int(len(ref_text_.encode("utf-8")) / ref_duration * (22 - ref_duration))
    gen_text_batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))
    for gen_audio, gen_sr in infer_batch_process(
        (audio, sr),
        ref_text_,
        gen_text_batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=nfe_step,
        speed=speed,
    ):
        yield audio_postprocess(gen_audio, gen_sr, TARGET_SR).tobytes()
def generate_with_audio(gen_text, ref_audio, ref_text, fast_infer=0):
    """Stream synthesized speech for *gen_text* using an explicit reference clip.

    Args:
        gen_text: text to synthesize.
        ref_audio: encoded reference audio bytes (decoded with torchaudio).
        ref_text: transcript of *ref_audio*.
        fast_infer: >= 1 selects 7 NFE steps (otherwise 16); > 0 also cuts
            the first text chunk at its first sentence boundary.

    Yields:
        ``audio_postprocess(...).tobytes()`` chunks at ``TARGET_SR``, or one
        block of ``N_ZEROS`` int16 zeros when *gen_text* has no words.
    """
    if not contains_words(gen_text):
        yield np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return

    steps = 7 if fast_infer >= 1 else 16
    wav, wav_sr = torchaudio.load(io.BytesIO(ref_audio))
    ref_seconds = wav.shape[-1] / wav_sr
    # Per-chunk text budget: reference UTF-8 bytes per second times the time
    # left in a 22 s window, capped at 135 bytes.
    budget = min(int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds)), 135)
    batches = split_text(gen_text, max_chars=budget, cut_short_first=fast_infer > 0)

    stream = infer_batch_process(
        (wav, wav_sr),
        ref_text,
        batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=steps,
        speed=speed,
    )
    for piece, piece_sr in stream:
        yield audio_postprocess(piece, piece_sr, TARGET_SR).tobytes()
@app.post("/synthesize")
async def synthesize(request: Request):
    """Stream speech for ``text`` in a voice registered via /register_voice.

    JSON body:
        text (str): text to synthesize.
        audio_key (str): key returned by /register_voice.
        fast_infer (bool | int, optional): ``true`` maps to 2, ``false`` to 0,
            other values are converted with int().

    Returns:
        A streaming response of the bytes produced by generate() (no WAV header
        is written here, despite the audio/wav media type), or ``N_ZEROS``
        int16 zeros when ``text`` has nothing pronounceable.

    Raises:
        HTTPException(400): missing/ill-typed fields or an unknown audio_key.
    """
    data = await request.json()
    try:
        text = data['text']
        audio_key = data['audio_key']
        fast_infer = data.get('fast_infer', 0)
    except (KeyError, TypeError, AttributeError):
        raise HTTPException(status_code=400, detail="Missing 'text' or 'audio_key'") from None
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")

    # bool is checked explicitly: `fast_infer == True` would also match 1/1.0.
    if isinstance(fast_infer, bool):
        fast_infer = 2 if fast_infer else 0
    else:
        try:
            fast_infer = int(fast_infer)
        except (TypeError, ValueError):
            raise HTTPException(status_code=400, detail="Invalid fast_infer") from None

    if not contains_words(text):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')
    if not isinstance(audio_key, str) or audio_key not in voice_dict:
        raise HTTPException(status_code=400, detail="Invalid audio key")
    return StreamingResponse(generate(text, audio_key, fast_infer), media_type="audio/wav")
# Clark-notation prefix for attributes in the reserved xml: namespace
# (e.g. xml:lang).
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"


@app.post("/tts")
def predict(ssml: str = Body(...), fast_infer: Union[bool, int] = 0):
    """Synthesize the text of the first SSML ``<voice>`` element with the default voice.

    The reference voice is always ``std_ref_audio`` / ``std_ref_text``; the
    ``name`` and ``xml:lang`` attributes of ``<voice>`` are currently ignored.
    Only the text directly inside ``<voice>`` before its first child element
    is used (ElementTree ``.text``).

    Returns:
        A streaming audio/wav response from generate_with_audio, or a 400 JSON
        response when the SSML does not parse or has no ``<voice>`` element.
    """
    # NOTE(security): xml.etree.ElementTree is not hardened against maliciously
    # constructed XML; if this endpoint is reachable by untrusted clients,
    # consider a hardened parser.
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    # Match <voice> in any namespace or none: standard SSML declares
    # xmlns="http://www.w3.org/2001/10/synthesis", which a bare ".//voice"
    # misses. The root itself may also be the <voice> element.
    if root.tag == "voice" or root.tag.endswith("}voice"):
        voice_element = root
    else:
        voice_element = root.find(".//{*}voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # .text is None for an empty <voice/>; treat that as empty text (silence).
    transcription = (voice_element.text or "").strip()
    return StreamingResponse(
        generate_with_audio(transcription, std_ref_audio, std_ref_text, int(fast_infer)),
        media_type="audio/wav",
    )
# @app.get("/health_check")
# async def health_check():
# try:
# a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
# b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
# c = torch.matmul(a, b)
# if c.sum() == 10 * 20 * 10:
# return {"status": "ok"}
# else:
# raise HTTPException(status_code=503)
# except Exception as e:
# print(f'health_check failed')
# raise HTTPException(status_code=503)
if __name__ == "__main__":
    # Serve the module-level `app` (given as an import string) on every
    # interface, port 80.
    bind_host, bind_port = "0.0.0.0", 80
    uvicorn.run("f5_server:app", host=bind_host, port=bind_port)