import os import sys import traceback from typing import Generator import logging logging.basicConfig( format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO"), ) logger = logging.getLogger(__file__) import torch from torch import Tensor from typing import Optional, List import torch.nn.functional as F # torch.manual_seed(0) now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import subprocess import io import signal import numpy as np import soundfile as sf from fastapi import FastAPI, Request, Response, Body, HTTPException, UploadFile, File, Form from fastapi.responses import StreamingResponse, JSONResponse from contextlib import asynccontextmanager import uvicorn from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from dataclasses import dataclass import hashlib import time from fast_langdetect import detect_language import xml.etree.ElementTree as ET import base64 import json #from redis.cluster import RedisCluster from redis import Redis model_dir = os.getenv('MODEL_DIR', '/mnt/models/GPT-SoVITS') model_name = os.getenv('MODEL_NAME', 's1v3.ckpt') redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") rds_key_prefix = 'tts:voice:' # print(sys.path) i18n = I18nAuto() tts_pipeline = None # @dataclass # class RefAudioMeta: # # audio: bytes # audio_path: str # text: str # lang: str # # slice_audio: Optional[bytes] = None # # slice_text: Optional[str] = None # voice_dict: dict[str, RefAudioMeta] = {} def init(): global tts_pipeline gsv_config = { # "version": "v2ProPlus", "custom": { "bert_base_path": os.path.join(model_dir, "chinese-roberta-wwm-ext-large"), "cnhuhbert_base_path": os.path.join(model_dir, "chinese-hubert-base"), "device": "npu", "is_half": False, "t2s_weights_path": os.path.join(model_dir, model_name), "version": "v2ProPlus", "vits_weights_path": os.path.join(model_dir, "v2Pro/s2Gv2ProPlus.pth") } } tts_config = TTS_Config(gsv_config) # tts_config = TTS_Config(config_path) print(tts_config) tts_pipeline = TTS(tts_config) try: with open('/workspace/wav/ningguang.wav', 'rb') as f: mandarin_voice_bytes = f.read() text = "而这条街道,没有半分“不谐”之感,实属难得。" register_voice_to_redis(mandarin_voice_bytes, text, audio_key='zh') except: logger.warning("Failed to register zh voice, skipping registration.") try: with open('/workspace/wav/bbc_real_en.wav', 'rb') as f: en_voice_bytes = f.read() text = "Hello and welcome to Real Easy English. In this podcast, we have real conversations in easy English to help you learn." register_voice_to_redis(en_voice_bytes, text, audio_key='en') except: logger.warning("Failed to register en voice, skipping registration.") @asynccontextmanager async def lifespan(app: FastAPI): init() yield pass app = FastAPI(lifespan=lifespan) ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: audio_file.write(data) return io_buffer def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): io_buffer.write(data.tobytes()) return io_buffer def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): io_buffer = BytesIO() sf.write(io_buffer, data, rate, format="wav") return io_buffer def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): process = subprocess.Popen( [ "ffmpeg", "-f", "s16le", # 输入16位有符号小端整数PCM "-ar", str(rate), # 设置采样率 "-ac", "1", # 单声道 "-i", "pipe:0", # 从管道读取输入 "-c:a", "aac", # 音频编码器为AAC "-b:a", "192k", # 比特率 "-vn", # 不包含视频 "-f", "adts", # 输出AAC数据流格式 "pipe:1", # 将输出写入管道 ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) out, _ = process.communicate(input=data.tobytes()) io_buffer.write(out) return io_buffer def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): if media_type == "ogg": io_buffer = pack_ogg(io_buffer, data, rate) elif media_type == "aac": io_buffer = pack_aac(io_buffer, data, rate) elif media_type == "wav": io_buffer = pack_wav(io_buffer, data, rate) else: io_buffer = pack_raw(io_buffer, data, rate) io_buffer.seek(0) return io_buffer from scipy.signal import resample def resample_audio(data: np.ndarray, original_rate: int, target_rate: int): ori_dtype = data.dtype number_of_samples = int(len(data) * float(target_rate) / original_rate) resampled_data = resample(data, number_of_samples) return resampled_data.astype(ori_dtype) def pack_audio_rate(io_buffer: BytesIO, data: np.ndarray, original_rate: int, target_rate: int, media_type: str): if target_rate and target_rate != original_rate: data = resample_audio(data, original_rate, target_rate) rate = target_rate else: rate = original_rate if data.dtype == np.int16: data = data.astype(np.float32) / np.max(np.abs(data)) * 32767 # Normalize to int16 range data = data.astype(np.int16) else: data = data / np.max(np.abs(data)) if media_type == "ogg": io_buffer = pack_ogg(io_buffer, data, rate) elif media_type == "aac": io_buffer = pack_aac(io_buffer, data, rate) elif media_type == "wav": io_buffer = pack_wav(io_buffer, data, rate) else: io_buffer = pack_raw(io_buffer, data, rate) io_buffer.seek(0) return io_buffer def encode_audio_key(audio_bytes: bytes) -> str: return hashlib.md5(audio_bytes[:16000]).hexdigest()[:16] def register_voice_to_redis(audio_bytes, text, audio_key: Optional[str] = None, force: bool = False): if audio_key is None: audio_key = encode_audio_key(audio_bytes) # Ensure ref_text ends with a proper sentence-ending punctuation if not text.endswith(". ") and not text.endswith("。"): if text.endswith("."): text += " " else: text += ". " voice = { 'audio': base64.b64encode(audio_bytes).decode('utf-8'), 'text': text, } already_exists = False #with RedisCluster.from_url(redis_url) as r: with Redis.from_url(redis_url) as r: redis_key = f'{rds_key_prefix}{audio_key}' resp = r.set(redis_key, json.dumps(voice), nx=not force) if not force and not resp: already_exists = True logger.warning(f"Voice with key {audio_key} already exists in Redis, skipping registration.") logger.info(f"Registered voice with key: {audio_key}, text: {text}") return audio_key, already_exists @app.post("/register_voice") async def register_voice( audio: UploadFile = File(...), text: str = Form(...), audio_name: Optional[str] = Form(None), force: bool = Form(False) ): audio_bytes = await audio.read() if audio_name == '': audio_name = None try: audio_key, already_exists = register_voice_to_redis(audio_bytes, text, audio_name, force=force) except Exception as e: logger.warning(f"Failed to register voice: {str(e)}") return JSONResponse(status_code=400, content={"error": str(e)}) # warmup for _ in generate("流式语音合成,合成测试一", ref_audio_key=audio_key, fast_infer=1): logger.info("Warming up 1") for _ in generate("流式语音合成,合成测试二", ref_audio_key=audio_key, fast_infer=2): logger.info("Warming up 2") response = { "status": "success" if not already_exists else "already_exists", "audio_key": audio_key } return JSONResponse(status_code=200, content=response) def generate(gen_text, text_lang="zh", ref_audio=None, ref_text=None, ref_audio_key=None, fast_infer=0): if ref_audio_key is not None: t1 = time.perf_counter() #with RedisCluster.from_url(redis_url) as r: with Redis.from_url(redis_url) as r: voice_data = r.get(f'{rds_key_prefix}{ref_audio_key}') if not voice_data: raise Exception(f'Voice {ref_audio_key} not found.') voice_data = json.loads(voice_data) t2 = time.perf_counter() logger.info(f"Loaded voice {ref_audio_key} from Redis in {t2 - t1:.3f} seconds") if fast_infer >= 2 and 'slice_audio' in voice_data: ref_audio = base64.b64decode(voice_data['slice_audio']) ref_text = voice_data['slice_text'] else: ref_audio = base64.b64decode(voice_data['audio']) ref_text = voice_data['text'] with open(f"/workspace/wav/{ref_audio_key}.wav", "wb") as f: f.write(ref_audio) ref_audio_path = f"/workspace/wav/{ref_audio_key}.wav" ref_lang = detect_language(ref_text).lower() if ref_text else text_lang elif ref_audio is not None: if isinstance(ref_audio, str): ref_audio_path = ref_audio else: audio_key = encode_audio_key(ref_audio) if not os.path.exists(f"/workspace/wav/{audio_key}.wav"): with open(f"/workspace/wav/{audio_key}.wav", "wb") as f: f.write(ref_audio) ref_audio_path = f"/workspace/wav/{audio_key}.wav" ref_lang = detect_language(ref_text).lower() if ref_text else text_lang req = { "text": gen_text, "text_lang": text_lang, "ref_audio_path": ref_audio_path, "prompt_text": ref_text, "prompt_lang": ref_lang, "text_split_method": "cut2", "media_type": "wav", "speed_factor": 1.0, "parallel_infer": False, "batch_size": 1, "split_bucket": False, "streaming_mode": True } streaming_mode = req.get("streaming_mode", False) return_fragment = req.get("return_fragment", False) media_type = req.get("media_type", "wav") # check_res = check_params(req) # if check_res is not None: # return check_res if streaming_mode or return_fragment: req["return_fragment"] = True tts_generator = tts_pipeline.run(req) for sr, chunk in tts_generator: yield pack_audio_rate(BytesIO(), chunk, sr, target_rate=16000, media_type=None).getvalue() @app.post("/synthesize") async def synthesize(request: Request): data = await request.json() text = data['text'] audio_key = data['audio_key'] language = data.get('language', 'zh') # fast_infer = data.get('fast_infer', 0) # if fast_infer == True: # fast_infer = 2 # else: # fast_infer = int(fast_infer) logger.info(f"Synthesizing text: {text}, audio_key: {audio_key}") return StreamingResponse( generate(text, text_lang=language, ref_audio_key=audio_key), media_type="audio/wav" ) @app.post("/synthesize_with_audio") async def synthesize_with_audio( ref_audio: UploadFile = File(...), ref_text: str = Form(...), text: str = Form(...), lang: str = Form("zh"), fast_infer: int = Form(0) ): logger.info(f"Synthesizing with audio, text: {text}, ref_text: {ref_text}, fast_infer: {fast_infer}") audio_bytes = await ref_audio.read() return StreamingResponse( generate(text, text_lang=lang, ref_audio=audio_bytes, ref_text=ref_text), media_type="audio/wav" ) xml_namespace = "{http://www.w3.org/XML/1998/namespace}" @app.post("/") @app.post("/tts") def predict(ssml: str = Body(...)): try: root = ET.fromstring(ssml) voice_element = root.find(".//voice") if voice_element is not None: text = voice_element.text.strip() language = voice_element.get(f'{xml_namespace}lang', "zh").strip() voice_name = voice_element.get("name", "zh").strip() else: return JSONResponse(status_code=400, content={"message": "Invalid SSML format: element not found."}) except ET.ParseError as e: return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)}) return StreamingResponse( generate(text, language, ref_audio_key=voice_name), media_type=f"audio/wav", ) @app.get("/ready") @app.get("/health") async def ready(): return JSONResponse(status_code=200, content={"status": "ok"}) @app.get("/health_check") async def health_check(): try: a = torch.ones(10, 20, dtype=torch.float32, device='npu') b = torch.ones(20, 10, dtype=torch.float32, device='npu') c = torch.matmul(a, b) if c.sum() == 10 * 20 * 10: return {"status": "ok"} else: raise HTTPException(status_code=503) except Exception as e: print(f'health_check failed') raise HTTPException(status_code=503) if __name__ == "__main__": try: uvicorn.run(app=app, host="0.0.0.0", port=80, workers=1) except Exception: traceback.print_exc() os.kill(os.getpid(), signal.SIGTERM) exit(0)