Files
enginex-bi_series-tts/bi_v100-gpt-sovits/gsv_server.py
2025-08-14 10:02:15 +08:00

246 lines
7.7 KiB
Python

import os
import sys
import traceback
import logging
# LOGLEVEL may be a level name in any case ("debug", "INFO") or a numeric
# level ("10"); logging.basicConfig itself only accepts upper-case names or ints,
# so e.g. LOGLEVEL=debug used to crash the server at import time.
_log_level = os.environ.get("LOGLEVEL", "INFO").strip().upper()
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=int(_log_level) if _log_level.isdigit() else _log_level,
)
logger = logging.getLogger(__file__)
import torch
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
# Restrict scaled-dot-product attention to PyTorch's plain "math" kernel by
# turning off the flash and memory-efficient CUDA backends.
# NOTE(review): presumably those kernels are unsupported or unreliable on the
# target accelerator — confirm.
_cuda_backend = torch.backends.cuda
_cuda_backend.enable_flash_sdp(False)
_cuda_backend.enable_mem_efficient_sdp(False)
_cuda_backend.enable_math_sdp(True)
del _cuda_backend
def custom_conv1d_forward(self, input: Tensor) -> Tensor:
    """Replacement for ``torch.nn.Conv1d.forward`` (installed right below).

    float16 inputs on a CUDA device are convolved inside an autocast region
    whose target dtype is float32, and the result is cast back with ``.half()``.
    Every other input takes the stock ``self._conv_forward`` path.
    """
    fp16_on_cuda = input.dtype == torch.float16 and input.device.type == 'cuda'
    if not fp16_on_cuda:
        return self._conv_forward(input, self.weight, self.bias)
    # NOTE(review): presumably a precision/kernel-support workaround for the
    # target accelerator; whether CUDA autocast honours float32 as a target
    # dtype depends on the torch build — confirm.
    with torch.amp.autocast(input.device.type, dtype=torch.float):
        return self._conv_forward(input, self.weight, self.bias).half()


# Monkey-patch: every torch.nn.Conv1d in this process now uses the override above.
torch.nn.Conv1d.forward = custom_conv1d_forward
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement for ``torch.nn.ConvTranspose1d.forward`` (installed right below).

    Mirrors PyTorch's own implementation, except that float32 inputs on a CUDA
    device are computed inside a float16 autocast region and cast back with
    ``.float()``. All other inputs call ``F.conv_transpose1d`` directly.

    Raises:
        ValueError: if the module's padding mode is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
    assert isinstance(self.padding, tuple)
    # _output_padding takes a List (not Tuple/Sequence) because TorchScript
    # supports neither `Sequence[T]` nor `Tuple[T, ...]`.
    spatial_dims = 1
    extra_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        spatial_dims, self.dilation)  # type: ignore[arg-type]
    conv_args = (self.weight, self.bias, self.stride, self.padding,
                 extra_padding, self.groups, self.dilation)

    fp32_on_cuda = input.dtype == torch.float and input.device.type == 'cuda'
    if not fp32_on_cuda:
        return F.conv_transpose1d(input, *conv_args)
    # NOTE(review): opposite direction from custom_conv1d_forward (fp32 -> fp16
    # here); presumably a speed or kernel-support workaround for the target
    # accelerator — confirm.
    with torch.amp.autocast('cuda', dtype=torch.float16):
        return F.conv_transpose1d(input, *conv_args).float()


# Monkey-patch: every torch.nn.ConvTranspose1d in this process now uses the override above.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
# Enter the GPT-SoVITS checkout under the launch directory and make both the
# checkout root and its GPT_SoVITS/ directory importable.
os.chdir(f'{os.getcwd()}/GPT-SoVITS')
now_dir = os.getcwd()
sys.path.insert(0, now_dir)  # checkout root is searched before everything else
sys.path.append(f"{now_dir}/GPT_SoVITS")
import sv  # noqa: E402 -- must come after the sys.path setup above

# Point the `sv` module (presumably speaker verification) at its checkpoint.
# NOTE(review): when MODEL_DIR is unset, this falls back to
# "GPT_SoVITS/pretrained_models" (relative to the checkout), whereas model_dir
# falls back to "/mnt/models/GPT-SoVITS" — confirm this split is intended.
sv.sv_path = os.path.join(
    os.getenv("MODEL_DIR", "GPT_SoVITS/pretrained_models"),
    "sv/pretrained_eres2netv2w24s4ep4.ckpt",
)
import subprocess
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
import hashlib
from fast_langdetect import detect_language
# Root directory of the pretrained weights loaded by init(); override with MODEL_DIR.
model_dir = os.getenv('MODEL_DIR', '/mnt/models/GPT-SoVITS')
i18n = I18nAuto()
# Set by init(), which the FastAPI lifespan hook calls at startup; None until then.
tts_pipeline = None


def init():
    """Build the GPT-SoVITS pipeline from weights under ``model_dir``.

    The "custom" config selects version "v2ProPlus" on device "cuda" with
    ``is_half`` False. The constructed ``TTS`` object is stored in the
    module-global ``tts_pipeline``.
    """
    global tts_pipeline
    custom_settings = {
        "bert_base_path": os.path.join(model_dir, "chinese-roberta-wwm-ext-large"),
        "cnhuhbert_base_path": os.path.join(model_dir, "chinese-hubert-base"),
        "device": "cuda",
        "is_half": False,
        "t2s_weights_path": os.path.join(model_dir, "s1v3.ckpt"),
        "version": "v2ProPlus",
        "vits_weights_path": os.path.join(model_dir, "v2Pro/s2Gv2ProPlus.pth"),
    }
    pipeline_config = TTS_Config({"custom": custom_settings})
    tts_pipeline = TTS(pipeline_config)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the TTS pipeline once at startup.

    Nothing is released on shutdown.
    """
    init()
    yield


app = FastAPI(lifespan=lifespan)
# Audio packing helpers adapted from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as a mono Ogg stream into ``io_buffer`` and return that buffer.

    Uses soundfile with ``format="ogg"`` and soundfile's default subtype for it.
    """
    ogg_writer = sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg")
    with ogg_writer:
        ogg_writer.write(data)
    return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Append the array's raw bytes (no header) to ``io_buffer`` and return it.

    ``rate`` is unused; it is accepted only so all pack_* helpers share one signature.
    """
    payload = data.tobytes()
    io_buffer.write(payload)
    return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write ``data`` as a WAV file into ``io_buffer`` and return that same buffer.

    Args:
        io_buffer: Destination buffer; expected to be empty (pack_audio passes a fresh one).
        data: Samples to encode (soundfile picks its default WAV subtype).
        rate: Sample rate in Hz written into the WAV header.
    """
    # Write into the caller's buffer, as the other pack_* helpers do. A fresh
    # BytesIO here would leave the caller's buffer empty for anyone who does
    # not use the return value.
    sf.write(io_buffer, data, rate, format="wav")
    return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode mono PCM samples to an ADTS AAC stream with ffmpeg and append it to ``io_buffer``.

    Args:
        io_buffer: Buffer the encoded stream is written to; it is also returned.
        data: Samples sent to ffmpeg as ``s16le``.
            NOTE(review): assumes ``data`` is int16; other dtypes would be
            misread by ffmpeg — confirm with callers.
        rate: Input sample rate in Hz.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status (message includes its stderr).
        FileNotFoundError: if the ``ffmpeg`` executable is not on PATH.
    """
    command = [
        "ffmpeg",
        "-f", "s16le",     # input: 16-bit signed little-endian PCM
        "-ar", str(rate),  # input sample rate
        "-ac", "1",        # mono
        "-i", "pipe:0",    # read input from stdin
        "-c:a", "aac",     # AAC audio encoder
        "-b:a", "192k",    # bitrate
        "-vn",             # no video
        "-f", "adts",      # output an ADTS AAC stream
        "pipe:1",          # write output to stdout
    ]
    result = subprocess.run(
        command,
        input=data.tobytes(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # Previously a failed encode silently produced empty/partial audio.
    if result.returncode != 0:
        details = result.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(f"ffmpeg AAC encoding failed (exit code {result.returncode}): {details}")
    io_buffer.write(result.stdout)
    return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Encode ``data`` with the packer chosen by ``media_type`` and rewind the result.

    "ogg", "aac" and "wav" select the matching pack_* helper; any other value
    (including None) falls back to pack_raw. Returns the buffer the chosen
    packer returned, positioned at offset 0.
    """
    named_packers = (("ogg", pack_ogg), ("aac", pack_aac), ("wav", pack_wav))
    packer = next((fn for name, fn in named_packers if media_type == name), pack_raw)
    packed = packer(io_buffer, data, rate)
    packed.seek(0)
    return packed
def encode_audio_key(audio_bytes: bytes) -> str:
    """Return the first 16 hex characters of the MD5 digest of ``audio_bytes``.

    Used by tts_generate to name cached reference-audio files; this is a
    content key, not a security measure.
    """
    full_digest = hashlib.md5(audio_bytes).hexdigest()
    return full_digest[:16]
def tts_generate(gen_text, text_lang="zh", ref_audio=None, ref_text=None):
    """Synthesize ``gen_text`` with the global pipeline, yielding audio bytes as they are produced.

    Args:
        gen_text: Text to synthesize.
        text_lang: Language code of ``gen_text``, passed to the pipeline as "text_lang".
        ref_audio: Reference voice, either a filesystem path (``str``) or the raw
            bytes of an uploaded file. Bytes are cached as
            ``/workspace/wav/<encode_audio_key(bytes)>.wav`` so identical uploads
            reuse one file. NOTE(review): the default ``None`` is not usable;
            hashing it raises TypeError.
        ref_text: Transcript of the reference audio. Its language comes from
            ``detect_language(ref_text).lower()``; when empty/None, ``text_lang`` is used.

    Yields:
        bytes: One item per ``(sample_rate, chunk)`` pair from ``tts_pipeline.run``,
        packed by ``pack_audio`` with ``media_type=None``, i.e. the chunk's raw
        bytes with no container header.
    """
    import uuid  # local import: only used to name the temporary cache file

    if isinstance(ref_audio, str):
        ref_audio_path = ref_audio
    else:
        audio_key = encode_audio_key(ref_audio)
        cache_dir = "/workspace/wav"
        os.makedirs(cache_dir, exist_ok=True)
        # NOTE(review): stored with a .wav suffix whatever the upload's real format is.
        ref_audio_path = f"{cache_dir}/{audio_key}.wav"
        if not os.path.exists(ref_audio_path):
            # Write under a unique temporary name, then atomically rename into
            # place, so a concurrent request for the same audio never reads a
            # half-written file.
            tmp_path = f"{ref_audio_path}.{uuid.uuid4().hex}.part"
            try:
                with open(tmp_path, "wb") as f:
                    f.write(ref_audio)
                os.replace(tmp_path, ref_audio_path)
            finally:
                # Only still present if the write or the rename failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

    ref_lang = detect_language(ref_text).lower() if ref_text else text_lang
    req = {
        "text": gen_text,
        "text_lang": text_lang,
        "ref_audio_path": ref_audio_path,
        "prompt_text": ref_text,
        "prompt_lang": ref_lang,
        "text_split_method": "cut2",
        "media_type": "wav",
        "speed_factor": 1.0,
        "parallel_infer": False,
        "batch_size": 1,
        "split_bucket": False,
        # Always stream: the pipeline then yields (sample_rate, chunk) pairs
        # incrementally instead of one final result.
        "streaming_mode": True,
        "return_fragment": True,
    }
    for sr, chunk in tts_pipeline.run(req):
        # media_type=None selects pack_raw: headerless bytes of the chunk array.
        yield pack_audio(BytesIO(), chunk, sr, media_type=None).getvalue()
# The body streamed by /generate is whatever tts_generate yields: raw chunk
# bytes with no container header. The original note called it 32 kHz pcm16;
# NOTE(review): confirm the rate/sample format against the loaded model.
@app.post("/generate")
async def generate(
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(...),
    text: str = Form(...),
    lang: str = Form("zh")
):
    """Stream synthesized speech for ``text`` in the voice of the uploaded reference clip.

    Form fields: ``ref_audio`` (reference clip file), ``ref_text`` (its
    transcript), ``text`` (text to speak) and ``lang`` (language of ``text``,
    default "zh").
    """
    reference_audio = await ref_audio.read()
    audio_chunks = tts_generate(
        text,
        text_lang=lang,
        ref_audio=reference_audio,
        ref_text=ref_text,
    )
    # NOTE(review): advertised as audio/wav although the stream carries no WAV
    # header; left unchanged for client compatibility.
    return StreamingResponse(audio_chunks, media_type="audio/wav")
@app.get("/ready")
@app.get("/health")
async def ready():
    """Probe endpoint for both /ready and /health.

    Always answers HTTP 200 with ``{"status": "ok"}``; it does not inspect
    ``tts_pipeline``.
    """
    body = {"status": "ok"}
    return JSONResponse(content=body, status_code=200)
if __name__ == "__main__":
    # Serve the app on all interfaces, port 80, in a single worker process.
    try:
        uvicorn.run(app=app, host="0.0.0.0", port=80, workers=1)
    except Exception:
        traceback.print_exc()
        # NOTE(review): presumably a hard stop so a failed server run ends the
        # process via SIGTERM instead of reaching the clean exit below — confirm intent.
        os.kill(os.getpid(), signal.SIGTERM)
    # sys.exit rather than the site-module exit() helper, which is meant for the
    # interactive interpreter and is missing when Python runs with -S.
    sys.exit(0)