246 lines
7.7 KiB
Python
246 lines
7.7 KiB
Python
import os
|
|
import sys
|
|
import traceback
|
|
|
|
import logging
|
|
# Configure root logging once at import time. The verbosity comes from the
# LOGLEVEL environment variable and falls back to "INFO" when it is unset.
logging.basicConfig(
    level=os.environ.get("LOGLEVEL", "INFO"),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
)
# The module logger is named after this file's path (__file__).
logger = logging.getLogger(__file__)
|
|
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from typing import Optional, List
|
|
import torch.nn.functional as F
|
|
|
|
# Scaled-dot-product-attention backend selection: leave only the math
# implementation enabled; the flash and memory-efficient kernels are disabled.
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
|
|
|
|
def custom_conv1d_forward(self, input: Tensor) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.Conv1d``.

    float16 input on a CUDA device is convolved inside an autocast region that
    requests ``torch.float`` and the result is converted back to half. Every
    other input goes straight through ``self._conv_forward``.

    NOTE(review): presumably intended to run half-precision convolutions in
    float32; confirm that the targeted PyTorch version accepts ``torch.float``
    as a CUDA autocast dtype.

    Args:
        input: tensor handed to the convolution.

    Returns:
        The convolution output (half precision on the float16/CUDA path).
    """
    is_half_on_cuda = input.dtype == torch.float16 and input.device.type == 'cuda'
    if not is_half_on_cuda:
        return self._conv_forward(input, self.weight, self.bias)

    with torch.amp.autocast(input.device.type, dtype=torch.float):
        result = self._conv_forward(input, self.weight, self.bias)
        return result.half()


# Monkey patch: every Conv1d module now uses the function above as forward().
torch.nn.Conv1d.forward = custom_conv1d_forward
|
|
|
|
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.ConvTranspose1d``.

    float32 input on a CUDA device is processed under float16 autocast and the
    result is cast back to float32; any other input is passed to
    ``F.conv_transpose1d`` directly.

    Args:
        input: tensor handed to the transposed convolution.
        output_size: optional requested output size, forwarded to
            ``self._output_padding``.

    Returns:
        The transposed-convolution output.

    Raises:
        ValueError: if ``self.padding_mode`` is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # `output_size` must stay a List: TorchScript supports neither
    # `Sequence[T]` nor `Tuple[T, ...]` in "_output_padding".
    num_spatial_dims = 1
    output_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        num_spatial_dims, self.dilation)  # type: ignore[arg-type]

    conv_args = (
        input, self.weight, self.bias, self.stride, self.padding,
        output_padding, self.groups, self.dilation,
    )
    if input.dtype != torch.float or input.device.type != 'cuda':
        return F.conv_transpose1d(*conv_args)

    with torch.amp.autocast('cuda', dtype=torch.float16):
        return F.conv_transpose1d(*conv_args).float()


# Monkey patch: every ConvTranspose1d module now uses the function above.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
|
|
|
|
|
|
# Change into the GPT-SoVITS directory below the current working directory and
# remember the resulting location.
os.chdir(f'{os.getcwd()}/GPT-SoVITS')
now_dir = os.getcwd()
# Search that directory first on import, and its GPT_SoVITS subdirectory last.
sys.path.insert(0, now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
|
|
|
import sv
|
|
# Point the `sv` module at its checkpoint below MODEL_DIR; when the variable is
# unset the relative "GPT_SoVITS/pretrained_models" directory is used.
sv.sv_path = os.path.join(
    os.environ.get("MODEL_DIR", "GPT_SoVITS/pretrained_models"),
    "sv/pretrained_eres2netv2w24s4ep4.ckpt",
)
|
|
|
|
import subprocess
|
|
import signal
|
|
import numpy as np
|
|
import soundfile as sf
|
|
from fastapi import FastAPI, UploadFile, File, Form
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from contextlib import asynccontextmanager
|
|
import uvicorn
|
|
from io import BytesIO
|
|
from tools.i18n.i18n import I18nAuto
|
|
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
|
import hashlib
|
|
from fast_langdetect import detect_language
|
|
|
|
# Directory the model weights are loaded from; overridable through MODEL_DIR.
model_dir = os.environ.get('MODEL_DIR', '/mnt/models/GPT-SoVITS')

# Global TTS pipeline; stays None until init() assigns it.
tts_pipeline = None
i18n = I18nAuto()
|
|
|
|
def init():
    """Create the global ``tts_pipeline`` from weights located under ``model_dir``.

    Builds a "custom" configuration (CUDA device, full precision, version
    "v2ProPlus"), wraps it in ``TTS_Config`` and instantiates ``TTS`` with it.
    """
    global tts_pipeline

    def in_model_dir(relative_path):
        # Resolve a weight file or directory relative to the model root.
        return os.path.join(model_dir, relative_path)

    custom_settings = {
        "bert_base_path": in_model_dir("chinese-roberta-wwm-ext-large"),
        "cnhuhbert_base_path": in_model_dir("chinese-hubert-base"),
        "device": "cuda",
        "is_half": False,
        "t2s_weights_path": in_model_dir("s1v3.ckpt"),
        "version": "v2ProPlus",
        "vits_weights_path": in_model_dir("v2Pro/s2Gv2ProPlus.pth"),
    }
    tts_pipeline = TTS(TTS_Config({"custom": custom_settings}))
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: run ``init()`` before the application starts serving."""
    init()
    # No teardown work is performed after the yield.
    yield


app = FastAPI(lifespan=lifespan)
|
|
|
|
### Modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
|
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as single-channel OGG at ``rate`` Hz into ``io_buffer``.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    ogg_writer = sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg")
    with ogg_writer:
        ogg_writer.write(data)
    return io_buffer
|
|
|
|
|
|
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Append the raw bytes of ``data`` to ``io_buffer`` without any container.

    ``rate`` is accepted for signature parity with the other packers and is
    not used.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    payload = data.tobytes()
    io_buffer.write(payload)
    return io_buffer
|
|
|
|
|
|
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write ``data`` as a WAV file at ``rate`` Hz into ``io_buffer``.

    The audio is written into the buffer supplied by the caller, like the
    other ``pack_*`` helpers, instead of a freshly created one.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    sf.write(io_buffer, data, rate, format="wav")
    return io_buffer
|
|
|
|
|
|
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` to AAC (ADTS stream) with ffmpeg and append it to ``io_buffer``.

    The bytes of ``data`` are piped to ffmpeg, which is told to read them as
    signed 16-bit little-endian mono PCM at ``rate`` Hz. Whatever ffmpeg
    writes to stdout is appended to ``io_buffer``. A non-zero ffmpeg exit
    status is logged together with its stderr output instead of being
    silently discarded; the (possibly empty) output is still written.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    command = [
        "ffmpeg",
        "-f", "s16le",      # input: signed 16-bit little-endian PCM
        "-ar", str(rate),   # input sample rate
        "-ac", "1",         # mono
        "-i", "pipe:0",     # read the input from stdin
        "-c:a", "aac",      # AAC audio encoder
        "-b:a", "192k",     # bitrate
        "-vn",              # no video
        "-f", "adts",       # output as an AAC ADTS stream
        "pipe:1",           # write the output to stdout
    ]
    completed = subprocess.run(
        command,
        input=data.tobytes(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if completed.returncode != 0:
        logging.getLogger(__file__).error(
            "ffmpeg exited with status %d while encoding AAC: %s",
            completed.returncode,
            completed.stderr.decode("utf-8", errors="replace").strip(),
        )
    io_buffer.write(completed.stdout)
    return io_buffer
|
|
|
|
|
|
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Pack ``data`` into ``io_buffer`` in the format named by ``media_type``.

    "ogg", "aac" and "wav" select the matching packer; any other value
    (including ``None``) falls back to raw bytes.

    Returns:
        The buffer returned by the packer, rewound to position 0.
    """
    encoders = {"ogg": pack_ogg, "aac": pack_aac, "wav": pack_wav}
    encode = encoders.get(media_type, pack_raw)
    packed = encode(io_buffer, data, rate)
    packed.seek(0)
    return packed
|
|
|
|
|
|
def encode_audio_key(audio_bytes: bytes) -> str:
    """Return the first 16 hex characters of the MD5 digest of ``audio_bytes``."""
    hasher = hashlib.md5()
    hasher.update(audio_bytes)
    return hasher.hexdigest()[:16]
|
|
|
|
def tts_generate(gen_text, text_lang="zh", ref_audio=None, ref_text=None):
    """Stream synthesized speech for ``gen_text`` as raw audio byte chunks.

    This is a generator: nothing runs (and nothing is raised) until the first
    chunk is requested.

    Args:
        gen_text: text to synthesize.
        text_lang: language code passed to the pipeline as ``text_lang``.
        ref_audio: reference audio, either a file path (``str``) or the audio
            file content as bytes. Bytes are stored once under
            ``/workspace/wav/<key>.wav``, where the key is derived from the
            content, and reused on later calls.
        ref_text: transcript of the reference audio. When given, its language
            is detected and used as the prompt language; otherwise
            ``text_lang`` is used.

    Yields:
        bytes: one packed chunk per fragment produced by the pipeline. Chunks
        are packed with ``media_type=None``, i.e. as raw bytes with no
        container.

    Raises:
        TypeError: if ``ref_audio`` is ``None``.
    """
    if ref_audio is None:
        # Fail with a clear message instead of an opaque hashing error.
        raise TypeError("ref_audio must be a file path (str) or the reference audio bytes, not None")

    if isinstance(ref_audio, str):
        ref_audio_path = ref_audio
    else:
        audio_key = encode_audio_key(ref_audio)
        os.makedirs("/workspace/wav", exist_ok=True)
        ref_audio_path = f"/workspace/wav/{audio_key}.wav"
        # Identical uploads map to the same key, so write the file only once.
        if not os.path.exists(ref_audio_path):
            with open(ref_audio_path, "wb") as f:
                f.write(ref_audio)

    ref_lang = detect_language(ref_text).lower() if ref_text else text_lang

    req = {
        "text": gen_text,
        "text_lang": text_lang,
        "ref_audio_path": ref_audio_path,
        "prompt_text": ref_text,
        "prompt_lang": ref_lang,
        "text_split_method": "cut2",
        "media_type": "wav",
        "speed_factor": 1.0,
        "parallel_infer": False,
        "batch_size": 1,
        "split_bucket": False,
        "streaming_mode": True,
        # Streaming is always on, so fragments are always requested.
        "return_fragment": True,
    }

    for sr, chunk in tts_pipeline.run(req):
        # Chunks are deliberately packed raw, regardless of req["media_type"].
        yield pack_audio(BytesIO(), chunk, sr, media_type=None).getvalue()
|
|
|
|
# Returns 32 kHz 16-bit PCM.
@app.post("/generate")
async def generate(
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(...),
    text: str = Form(...),
    lang: str = Form("zh")
):
    """Synthesize ``text`` using the uploaded reference audio and stream the result.

    Args:
        ref_audio: uploaded reference audio file.
        ref_text: transcript of the reference audio.
        text: text to synthesize.
        lang: language code of ``text`` (defaults to "zh").

    Returns:
        A ``StreamingResponse`` over the chunks produced by ``tts_generate``.
    """
    audio_bytes = await ref_audio.read()
    audio_stream = tts_generate(text, text_lang=lang, ref_audio=audio_bytes, ref_text=ref_text)
    # NOTE(review): the stream is declared as "audio/wav" while the chunks are
    # packed without a container; confirm clients expect headerless PCM.
    return StreamingResponse(audio_stream, media_type="audio/wav")
|
|
|
|
@app.get("/ready")
@app.get("/health")
async def ready():
    """Probe endpoint for /ready and /health: always answers 200 with ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return JSONResponse(content=payload, status_code=200)
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the app on all interfaces, port 80, with a single worker.
    try:
        uvicorn.run(app=app, host="0.0.0.0", port=80, workers=1)
    except Exception:
        traceback.print_exc()
        # Ask the process to terminate so no lingering work keeps it alive.
        os.kill(os.getpid(), signal.SIGTERM)
        # Fallback in case the signal does not end the process: exit with a
        # failure status (sys.exit rather than the site-provided exit()).
        sys.exit(1)
|