import os
import sys
import traceback
import logging

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)

import torch
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F

# Force the math SDP backend; flash / memory-efficient attention kernels are disabled.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)


def custom_conv1d_forward(self, input: Tensor) -> Tensor:
    # Run half-precision CUDA inputs through the conv in float32, then cast back to fp16.
    if input.dtype == torch.float16 and input.device.type == 'cuda':
        with torch.amp.autocast(input.device.type, dtype=torch.float):
            return self._conv_forward(input, self.weight, self.bias).half()
    else:
        return self._conv_forward(input, self.weight, self.bias)


torch.nn.Conv1d.forward = custom_conv1d_forward


def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # One cannot replace List by Tuple or Sequence in "_output_padding" because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
    num_spatial_dims = 1
    output_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        num_spatial_dims, self.dilation)  # type: ignore[arg-type]

    # Run float32 CUDA inputs through the transposed conv in float16, then cast back to fp32.
    if input.dtype == torch.float and input.device.type == 'cuda':
        with torch.amp.autocast('cuda', dtype=torch.float16):
            return F.conv_transpose1d(
                input, self.weight, self.bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation).float()
    else:
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)


torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward

now_dir = os.getcwd()
os.chdir(f'{now_dir}/GPT-SoVITS')
now_dir = os.getcwd()
# sys.path.append(now_dir)
sys.path.insert(0, now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

import sv

sv.sv_path = os.path.join(
    os.getenv("MODEL_DIR", "GPT_SoVITS/pretrained_models"),
    "sv/pretrained_eres2netv2w24s4ep4.ckpt")

import subprocess
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
import hashlib
from fast_langdetect import detect_language

model_dir = os.getenv('MODEL_DIR', '/mnt/models/GPT-SoVITS')
# print(sys.path)
i18n = I18nAuto()

tts_pipeline = None


def init():
    global tts_pipeline
    gsv_config = {
        # "version": "v2ProPlus",
        "custom": {
            "bert_base_path": os.path.join(model_dir, "chinese-roberta-wwm-ext-large"),
            "cnhuhbert_base_path": os.path.join(model_dir, "chinese-hubert-base"),
            "device": "cuda",
            "is_half": False,
            "t2s_weights_path": os.path.join(model_dir, "s1v3.ckpt"),
            "version": "v2ProPlus",
            "vits_weights_path": os.path.join(model_dir, "v2Pro/s2Gv2ProPlus.pth")
        }
    }
    tts_config = TTS_Config(gsv_config)
    # tts_config = TTS_Config(config_path)
    tts_pipeline = TTS(tts_config)


@asynccontextmanager
async def lifespan(app: FastAPI):
    init()
    yield


app = FastAPI(lifespan=lifespan)
mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: audio_file.write(data) return io_buffer def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): io_buffer.write(data.tobytes()) return io_buffer def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): io_buffer = BytesIO() sf.write(io_buffer, data, rate, format="wav") return io_buffer def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): process = subprocess.Popen( [ "ffmpeg", "-f", "s16le", # 输入16位有符号小端整数PCM "-ar", str(rate), # 设置采样率 "-ac", "1", # 单声道 "-i", "pipe:0", # 从管道读取输入 "-c:a", "aac", # 音频编码器为AAC "-b:a", "192k", # 比特率 "-vn", # 不包含视频 "-f", "adts", # 输出AAC数据流格式 "pipe:1", # 将输出写入管道 ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) out, _ = process.communicate(input=data.tobytes()) io_buffer.write(out) return io_buffer def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): if media_type == "ogg": io_buffer = pack_ogg(io_buffer, data, rate) elif media_type == "aac": io_buffer = pack_aac(io_buffer, data, rate) elif media_type == "wav": io_buffer = pack_wav(io_buffer, data, rate) else: io_buffer = pack_raw(io_buffer, data, rate) io_buffer.seek(0) return io_buffer def encode_audio_key(audio_bytes: bytes) -> str: return hashlib.md5(audio_bytes).hexdigest()[:16] def tts_generate(gen_text, text_lang="zh", ref_audio=None, ref_text=None): if isinstance(ref_audio, str): ref_audio_path = ref_audio else: audio_key = encode_audio_key(ref_audio) os.makedirs("/workspace/wav", exist_ok=True) if not os.path.exists(f"/workspace/wav/{audio_key}.wav"): with open(f"/workspace/wav/{audio_key}.wav", "wb") as f: f.write(ref_audio) ref_audio_path = f"/workspace/wav/{audio_key}.wav" ref_lang = detect_language(ref_text).lower() if ref_text else text_lang req = { "text": gen_text, "text_lang": text_lang, "ref_audio_path": ref_audio_path, "prompt_text": ref_text, "prompt_lang": ref_lang, "text_split_method": "cut2", "media_type": "wav", "speed_factor": 1.0, "parallel_infer": False, "batch_size": 1, "split_bucket": False, "streaming_mode": True } streaming_mode = req.get("streaming_mode", False) return_fragment = req.get("return_fragment", False) media_type = req.get("media_type", "wav") # check_res = check_params(req) # if check_res is not None: # return check_res if streaming_mode or return_fragment: req["return_fragment"] = True tts_generator = tts_pipeline.run(req) for sr, chunk in tts_generator: yield pack_audio(BytesIO(), chunk, sr, media_type=None).getvalue() # return 32kHz pcm16 @app.post("/generate") async def generate( ref_audio: UploadFile = File(...), ref_text: str = Form(...), text: str = Form(...), lang: str = Form("zh") ): audio_bytes = await ref_audio.read() return StreamingResponse( tts_generate(text, text_lang=lang, ref_audio=audio_bytes, ref_text=ref_text), media_type="audio/wav" ) @app.get("/ready") @app.get("/health") async def ready(): return JSONResponse(status_code=200, content={"status": "ok"}) if __name__ == "__main__": try: uvicorn.run(app=app, host="0.0.0.0", port=80, workers=1) except Exception: traceback.print_exc() os.kill(os.getpid(), signal.SIGTERM) exit(0)