"""F5-TTS streaming speech-synthesis server built on FastAPI.

Endpoints
---------
GET  /health, /ready   Readiness probe; always answers 200.
POST /register_voice   Multipart upload of a reference clip (``audio``) and its
                       transcript (``text``); returns an ``audio_key``.
POST /synthesize       JSON ``{"text": ..., "audio_key": ..., "fast_infer": ...}``;
                       streams speech in a registered voice.
POST /tts              SSML body; speaks the text of the ``<voice>`` element with
                       the reference voice loaded from disk at startup.

Audio is streamed as raw int16 samples at ``TARGET_SR`` Hz. No WAV header is
written, even though the responses use the ``audio/wav`` media type.
"""
import os

# Directory holding the checkpoint and vocab.txt, and the checkpoint file name.
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.safetensors")

import logging

# Configured before the heavy imports below: logging.basicConfig is a no-op once
# the root logger already has handlers, so it must run before any library does.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)

import torch

torch.set_num_threads(4)
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    infer_batch_process,
)
from omegaconf import OmegaConf
from hydra.utils import get_class
import re
import numpy as np
import soundfile as sf
import torchaudio
from scipy import signal
import io
import time
from fastapi import FastAPI, Request, Response, Body, HTTPException
from fastapi import UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import hashlib
import xml.etree.ElementTree as ET
from typing import Union

vocoder_dir = os.getenv('VOCODER_DIR', '/app/charactr/vocos-mel-24khz')
speed = float(os.getenv('SPEED', 1.0))
# F5-TTS model config (provides model.backbone and model.arch).
model_cfg_file = os.getenv('MODEL_CFG', '/app/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')

ema_model = None  # TTS model, set by init()
vocoder = None  # vocoder, set by init()
# audio_key -> {'ref_audio': raw uploaded file bytes, 'ref_text': transcript}
voice_dict = {}

# Inference device; defaults to 'mlu'. Override with e.g. DEVICE=cuda or DEVICE=cpu.
device = os.getenv('DEVICE', 'mlu')

# Sample rate (Hz) of the audio streamed to clients.
TARGET_SR = 16000
# Number of int16 zero samples returned for text with no speakable characters.
N_ZEROS = 20

# ===== AMP dtype config (default: bf16; set AMP_DTYPE=fp16 to force fp16) — currently disabled =====
# AMP_DTYPE_ENV = os.getenv("AMP_DTYPE", "bf16").lower()
# def _amp_dtype_for_mlu():
#     return torch.float16 if AMP_DTYPE_ENV in ("fp16", "float16", "16") else torch.bfloat16
# def mlu_autocast():
#     # torch.autocast supports device_type="mlu"
#     return torch.autocast(device_type="mlu", dtype=_amp_dtype_for_mlu())

# Built-in reference voice used by /tts; both files are read by init().
# std_ref_audio_file = os.path.join(model_dir, 'ref_audio.wav')
# std_ref_text_file = os.path.join(model_dir, 'ref_text.txt')
std_ref_audio_file = '/app/ref_audio.wav'
std_ref_text_file = '/app/ref_text.txt'
std_ref_audio = None
std_ref_text = None


def init():
    """Load the vocoder, the F5-TTS model and the built-in reference voice.

    Populates the module globals ``vocoder``, ``ema_model``, ``std_ref_audio``
    (raw bytes of ``std_ref_audio_file``) and ``std_ref_text`` (stripped
    contents of ``std_ref_text_file``).
    """
    global ema_model, vocoder
    global std_ref_audio, std_ref_text
    logger.info(f'{device=}')

    # load vocoder
    vocoder_name = 'vocos'
    vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=True, local_path=vocoder_dir, device=device)

    # load TTS model
    model_cfg = OmegaConf.load(model_cfg_file)
    model_cls = get_class(f'f5_tts.model.{model_cfg.model.backbone}')
    model_arc = model_cfg.model.arch
    ckpt_file = os.path.join(model_dir, model_name)
    vocab_file = os.path.join(model_dir, 'vocab.txt')
    ema_model = load_model(
        model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
    )
    # Best effort: failing to switch to eval mode must not block startup, but
    # the failure is logged rather than silently discarded.
    try:
        ema_model.eval()
    except Exception:
        logger.warning("ema_model.eval() failed; using the model as returned by load_model", exc_info=True)

    with open(std_ref_audio_file, 'rb') as f:
        std_ref_audio = f.read()
    with open(std_ref_text_file, 'r', encoding='utf-8') as f:
        std_ref_text = f.read().strip()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models once at application startup (nothing is released on shutdown)."""
    init()
    yield


app = FastAPI(lifespan=lifespan)


@app.get("/health")
@app.get("/ready")
async def ready():
    """Readiness/liveness probe; always reports success."""
    return JSONResponse(status_code=200, content={"message": "success"})


def encode_audio_key(audio_bytes: bytes) -> str:
    """Return a 16-hex-character key: a truncated MD5 of the first 16000 bytes.

    NOTE(review): only the prefix is hashed. Two uploads that share their first
    16000 bytes (same header plus identical leading samples, e.g. leading
    silence) get the same key, and the later registration replaces the earlier
    one. Hashing the whole payload would avoid this but changes the keys.
    """
    return hashlib.md5(audio_bytes[:16000]).hexdigest()[:16]


@app.post("/register_voice")
async def register_voice(
    audio: UploadFile = File(...),
    text: str = Form(...)
):
    """Register a reference voice, warm the model up with it, and return its key.

    ``audio`` is the reference clip (any format ``torchaudio.load`` can read)
    and ``text`` is its transcript. Responds with ``{"status", "audio_key"}``.
    """
    audio_bytes = await audio.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty audio upload")
    audio_key = encode_audio_key(audio_bytes)

    # Ensure the reference transcript ends with sentence-final punctuation.
    # Stripping first keeps trailing whitespace from ending up in front of the
    # added period (previously "abc " was stored as "abc .").
    ref_text = text.strip()
    if not ref_text.endswith((".", "。")):
        ref_text += "."
    voice_dict[audio_key] = {
        'ref_audio': audio_bytes,
        'ref_text': ref_text,
    }

    def _warmup():
        # Run a short synthesis so the first real request for this voice is fast.
        for _ in generate("流式语音合成,合成测试", audio_key, fast_infer=2):
            logger.info("Warming up")

    # Inference blocks; run it off the event loop so other requests (health
    # probes, ongoing streams) keep being served during warmup.
    await run_in_threadpool(_warmup)

    response = {
        "status": "success",
        "audio_key": audio_key
    }
    return JSONResponse(status_code=200, content=response)


# Characters that do not count as speakable content (see contains_words).
symbols = """,.!?;:()[]{}<>,。!?;:【】《》……'"’“”_—"""


def contains_words(text):
    """Return True if ``text`` has a character that is neither whitespace nor in ``symbols``.

    Whitespace is ignored so that whitespace-only text, and punctuation-only
    chunks such as ``". "``, are not treated as speakable.
    """
    return any(not char.isspace() and char not in symbols for char in text)


# Sentence boundary: after ASCII punctuation followed by whitespace (whitespace
# consumed), or directly after any character of the second class (zero-width).
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])")


def split_text(text, max_chars=135, cut_short_first=False):
    """Split ``text`` into sentence-based chunks for batched synthesis.

    Sentences are packed greedily into chunks while the UTF-8 byte length of
    the chunk so far plus the next sentence stays within ``max_chars``. A
    sentence longer than that becomes its own chunk. Chunks without speakable
    characters are dropped.

    When ``cut_short_first`` is true, the first chunk is further split into its
    first sentence and the remainder (presumably so streaming can start on a
    short piece).

    Returns a list of stripped strings; it is empty when ``text`` has nothing
    speakable.
    """
    sentences = [s.strip() for s in _SENTENCE_SPLIT_RE.split(text) if s.strip()]

    chunks = []
    current_chunk = ""
    for sentence in sentences:
        # Sentences ending in a single-byte (e.g. ASCII) character get a
        # separating space; multi-byte (e.g. CJK) endings are joined directly.
        piece = sentence + " " if len(sentence[-1].encode("utf-8")) == 1 else sentence
        if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += piece
        else:
            if current_chunk and contains_words(current_chunk):
                chunks.append(current_chunk.strip())
            current_chunk = piece
    if current_chunk and contains_words(current_chunk):
        chunks.append(current_chunk.strip())

    if not (cut_short_first and chunks):
        return chunks

    # maxsplit=1 keeps the remainder verbatim; re-joining a full split with ""
    # would drop the whitespace consumed at each boundary ("a. b" -> "a.b").
    first_sentences = _SENTENCE_SPLIT_RE.split(chunks[0], maxsplit=1)
    first = first_sentences[0].strip()
    rest = first_sentences[1].strip() if len(first_sentences) > 1 else ""
    first_chunk = [first, rest] if rest else [first]
    return first_chunk + chunks[1:]


def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``ori_sr`` to ``target_sr`` Hz and return int16 samples.

    Floating-point input is treated as full scale [-1.0, 1.0]: it is clipped to
    that range and scaled by 32767. Integer input is rounded and clipped to the
    int16 range. Previously only float32 was converted, and any other dtype
    was returned as float64 samples. Returns an empty array when the resampled
    length rounds down to zero.
    """
    audio = np.asarray(audio)
    number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
    if number_of_samples <= 0:
        return np.zeros(0, dtype=np.int16)
    audio_resampled = signal.resample(audio, number_of_samples)
    if np.issubdtype(audio.dtype, np.floating):
        audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
        return (audio_resampled * 32767).astype(np.int16)
    int16_info = np.iinfo(np.int16)
    return np.clip(np.rint(audio_resampled), int16_info.min, int16_info.max).astype(np.int16)


def _stream_synthesis(gen_text, ref_audio_bytes, ref_text, fast_infer, max_chars_cap=None):
    """Yield int16 PCM byte chunks of ``gen_text`` spoken in the reference voice.

    ``ref_audio_bytes`` is an encoded audio file readable by ``torchaudio.load``
    and ``ref_text`` is its transcript. ``fast_infer >= 1`` lowers ``nfe_step``
    from 16 to 7, and ``fast_infer > 0`` cuts the first text batch short. Text
    without speakable characters yields a single block of ``N_ZEROS`` silent
    samples instead of running the model. ``max_chars_cap``, if given, caps the
    per-batch text budget.
    """
    if not contains_words(gen_text):
        yield np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return

    nfe_step = 7 if fast_infer >= 1 else 16
    audio, sr = torchaudio.load(io.BytesIO(ref_audio_bytes))
    ref_seconds = audio.shape[-1] / sr
    # Text budget per batch, in UTF-8 bytes: the reference's bytes-per-second
    # rate times the time left after the reference in a ~22 s window.
    max_chars = int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds))
    if max_chars_cap is not None:
        max_chars = min(max_chars, max_chars_cap)
    gen_text_batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))

    for gen_audio, gen_sr in infer_batch_process(
        (audio, sr),
        ref_text,
        gen_text_batches,
        ema_model,
        vocoder,
        device=device,
        streaming=True,
        chunk_size=int(24e6),
        nfe_step=nfe_step,
        speed=speed,
    ):
        yield audio_postprocess(gen_audio, gen_sr, TARGET_SR).tobytes()


def generate(gen_text, ref_audio_key, fast_infer=0):
    """Stream ``gen_text`` in the voice registered under ``ref_audio_key``.

    Yields int16 PCM byte chunks. Raises KeyError, on first iteration, if the
    key is not registered. The per-batch text budget is not capped.
    """
    voice = voice_dict[ref_audio_key]
    # nonuniform_step = True
    # if fast_infer >= 2:
    #     ref_audio_ = voice_dict[ref_audio_key].get('ref_audio_slice', ref_audio_)
    #     ref_text_ = voice_dict[ref_audio_key].get('ref_text_slice', ref_text_)
    yield from _stream_synthesis(gen_text, voice['ref_audio'], voice['ref_text'], fast_infer)


def generate_with_audio(gen_text, ref_audio, ref_text, fast_infer=0):
    """Stream ``gen_text`` using explicit reference audio bytes and transcript.

    Yields int16 PCM byte chunks; the per-batch text budget is capped at 135 bytes.
    """
    yield from _stream_synthesis(gen_text, ref_audio, ref_text, fast_infer, max_chars_cap=135)


@app.post("/synthesize")
async def synthesize(request: Request):
    """Stream speech for ``text`` in the voice registered as ``audio_key``.

    JSON body: ``text`` (str), ``audio_key`` (str) and optional ``fast_infer``
    (bool or int; ``true`` maps to 2). Malformed bodies and unknown keys get
    HTTP 400.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body is not valid JSON") from None
    try:
        text = data['text']
        audio_key = data['audio_key']
        fast_infer = data.get('fast_infer', 0)
    except (KeyError, TypeError, AttributeError):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object with 'text' and 'audio_key'"
        ) from None
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")

    # `is True`, not `== True`: since 1 == True in Python, the old check also
    # remapped an explicit integer 1 to 2.
    if fast_infer is True:
        fast_infer = 2
    else:
        try:
            fast_infer = int(fast_infer)
        except (TypeError, ValueError):
            raise HTTPException(status_code=400, detail="'fast_infer' must be a boolean or an integer") from None

    # logger.info(f"Synthesizing text: {text}, audio_key: {audio_key}, fast_infer: {fast_infer}")
    if not contains_words(text):
        audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
        return Response(audio, media_type='audio/wav')

    if not isinstance(audio_key, str) or audio_key not in voice_dict:
        raise HTTPException(status_code=400, detail="Invalid audio key")
    return StreamingResponse(generate(text, audio_key, fast_infer), media_type="audio/wav")


xml_namespace = "{http://www.w3.org/XML/1998/namespace}"


@app.post("/tts")
def predict(ssml: str = Body(...), fast_infer: Union[bool, int] = 0):
    """Speak the text of the SSML ``<voice>`` element with the built-in reference voice.

    The ``<voice>`` element may be the root or any descendant, in any XML
    namespace. Its full text content, including nested elements, is
    synthesized. Its ``xml:lang`` attribute is not used. Returns HTTP 400 for
    unparsable SSML or when no ``<voice>`` element exists.
    """
    # NOTE(review): xml.etree parses an untrusted request body here; see the
    # "XML vulnerabilities" section of the Python docs before exposing publicly.
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    # `{*}` matches the tag in any namespace (or none), so documents declaring
    # the standard SSML namespace are accepted as well.
    if root.tag == "voice" or root.tag.endswith("}voice"):
        voice_element = root
    else:
        voice_element = root.find(".//{*}voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # itertext() also covers text inside nested elements and is "" (never None)
    # for an empty element, unlike `.text`.
    transcription = "".join(voice_element.itertext()).strip()
    return StreamingResponse(
        generate_with_audio(transcription, std_ref_audio, std_ref_text, int(fast_infer)),
        media_type="audio/wav"
    )


# @app.get("/health_check")
# async def health_check():
#     try:
#         a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
#         b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
#         c = torch.matmul(a, b)
#         if c.sum() == 10 * 20 * 10:
#             return {"status": "ok"}
#         else:
#             raise HTTPException(status_code=503)
#     except Exception as e:
#         print(f'health_check failed')
#         raise HTTPException(status_code=503)


if __name__ == "__main__":
    # Pass the app object rather than the "f5_server:app" import string. This
    # works whatever this file is named and avoids importing the module twice.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 80)))