update README
This commit is contained in:
449
mlu_370-gpt-sovits/GPT-SoVITS/gsv_server.py
Normal file
449
mlu_370-gpt-sovits/GPT-SoVITS/gsv_server.py
Normal file
@@ -0,0 +1,449 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Generator
|
||||
|
||||
import logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=os.environ.get("LOGLEVEL", "INFO"),
|
||||
)
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Optional, List
|
||||
import torch.nn.functional as F
|
||||
# torch.manual_seed(0)
|
||||
|
||||
# torch.backends.cuda.enable_flash_sdp(False)
|
||||
# torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
# torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
# def custom_conv1d_forward(self, input: Tensor) -> Tensor:
|
||||
# if input.dtype == torch.float16 and input.device.type == 'cuda':
|
||||
# with torch.amp.autocast(input.device.type, dtype=torch.float):
|
||||
# return self._conv_forward(input, self.weight, self.bias).half()
|
||||
# else:
|
||||
# return self._conv_forward(input, self.weight, self.bias)
|
||||
|
||||
# torch.nn.Conv1d.forward = custom_conv1d_forward
|
||||
|
||||
# def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
# if self.padding_mode != 'zeros':
|
||||
# raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
|
||||
|
||||
# assert isinstance(self.padding, tuple)
|
||||
# # One cannot replace List by Tuple or Sequence in "_output_padding" because
|
||||
# # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
||||
# num_spatial_dims = 1
|
||||
# output_padding = self._output_padding(
|
||||
# input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
|
||||
# num_spatial_dims, self.dilation) # type: ignore[arg-type]
|
||||
# if input.dtype == torch.float and input.device.type == 'cuda':
|
||||
# with torch.amp.autocast('cuda', dtype=torch.float16):
|
||||
# return F.conv_transpose1d(
|
||||
# input, self.weight, self.bias, self.stride, self.padding,
|
||||
# output_padding, self.groups, self.dilation).float()
|
||||
# else:
|
||||
# return F.conv_transpose1d(
|
||||
# input, self.weight, self.bias, self.stride, self.padding,
|
||||
# output_padding, self.groups, self.dilation)
|
||||
|
||||
# torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
|
||||
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
import subprocess
|
||||
import io
|
||||
import signal
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Request, Response, Body, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
from io import BytesIO
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
import time
|
||||
from fast_langdetect import detect_language
|
||||
import xml.etree.ElementTree as ET
|
||||
import base64
|
||||
import json
|
||||
#from redis.cluster import RedisCluster
|
||||
from redis import Redis
|
||||
|
||||
model_dir = os.getenv('MODEL_DIR', '/mnt/models/GPT-SoVITS')
|
||||
model_name = os.getenv('MODEL_NAME', 's1v3.ckpt')
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
|
||||
rds_key_prefix = 'tts:voice:'
|
||||
|
||||
# print(sys.path)
|
||||
i18n = I18nAuto()
|
||||
tts_pipeline = None
|
||||
|
||||
# @dataclass
|
||||
# class RefAudioMeta:
|
||||
# # audio: bytes
|
||||
# audio_path: str
|
||||
# text: str
|
||||
# lang: str
|
||||
# # slice_audio: Optional[bytes] = None
|
||||
# # slice_text: Optional[str] = None
|
||||
|
||||
# voice_dict: dict[str, RefAudioMeta] = {}
|
||||
|
||||
def init():
|
||||
global tts_pipeline
|
||||
|
||||
gsv_config = {
|
||||
# "version": "v2ProPlus",
|
||||
"custom": {
|
||||
"bert_base_path": os.path.join(model_dir, "chinese-roberta-wwm-ext-large"),
|
||||
"cnhuhbert_base_path": os.path.join(model_dir, "chinese-hubert-base"),
|
||||
"device": "mlu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": os.path.join(model_dir, model_name),
|
||||
"version": "v2ProPlus",
|
||||
"vits_weights_path": os.path.join(model_dir, "v2Pro/s2Gv2ProPlus.pth")
|
||||
}
|
||||
}
|
||||
tts_config = TTS_Config(gsv_config)
|
||||
# tts_config = TTS_Config(config_path)
|
||||
print(tts_config)
|
||||
tts_pipeline = TTS(tts_config)
|
||||
|
||||
try:
|
||||
with open('/workspace/wav/ningguang.wav', 'rb') as f:
|
||||
mandarin_voice_bytes = f.read()
|
||||
text = "而这条街道,没有半分“不谐”之感,实属难得。"
|
||||
register_voice_to_redis(mandarin_voice_bytes, text, audio_key='zh')
|
||||
except:
|
||||
logger.warning("Failed to register zh voice, skipping registration.")
|
||||
|
||||
|
||||
try:
|
||||
with open('/workspace/wav/bbc_real_en.wav', 'rb') as f:
|
||||
en_voice_bytes = f.read()
|
||||
text = "Hello and welcome to Real Easy English. In this podcast, we have real conversations in easy English to help you learn."
|
||||
register_voice_to_redis(en_voice_bytes, text, audio_key='en')
|
||||
except:
|
||||
logger.warning("Failed to register en voice, skipping registration.")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
init()
|
||||
yield
|
||||
pass
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer.write(data.tobytes())
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer = BytesIO()
|
||||
sf.write(io_buffer, data, rate, format="wav")
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le", # 输入16位有符号小端整数PCM
|
||||
"-ar",
|
||||
str(rate), # 设置采样率
|
||||
"-ac",
|
||||
"1", # 单声道
|
||||
"-i",
|
||||
"pipe:0", # 从管道读取输入
|
||||
"-c:a",
|
||||
"aac", # 音频编码器为AAC
|
||||
"-b:a",
|
||||
"192k", # 比特率
|
||||
"-vn", # 不包含视频
|
||||
"-f",
|
||||
"adts", # 输出AAC数据流格式
|
||||
"pipe:1", # 将输出写入管道
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
from scipy.signal import resample
|
||||
|
||||
def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
|
||||
ori_dtype = data.dtype
|
||||
number_of_samples = int(len(data) * float(target_rate) / original_rate)
|
||||
resampled_data = resample(data, number_of_samples)
|
||||
return resampled_data.astype(ori_dtype)
|
||||
|
||||
def pack_audio_rate(io_buffer: BytesIO, data: np.ndarray, original_rate: int, target_rate: int, media_type: str):
|
||||
if target_rate and target_rate != original_rate:
|
||||
data = resample_audio(data, original_rate, target_rate)
|
||||
rate = target_rate
|
||||
else:
|
||||
rate = original_rate
|
||||
|
||||
if data.dtype == np.int16:
|
||||
data = data.astype(np.float32) / np.max(np.abs(data)) * 32767 # Normalize to int16 range
|
||||
data = data.astype(np.int16)
|
||||
else:
|
||||
data = data / np.max(np.abs(data))
|
||||
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def encode_audio_key(audio_bytes: bytes) -> str:
|
||||
return hashlib.md5(audio_bytes[:16000]).hexdigest()[:16]
|
||||
|
||||
def register_voice_to_redis(audio_bytes, text, audio_key: Optional[str] = None, force: bool = False):
|
||||
if audio_key is None:
|
||||
audio_key = encode_audio_key(audio_bytes)
|
||||
# Ensure ref_text ends with a proper sentence-ending punctuation
|
||||
if not text.endswith(". ") and not text.endswith("。"):
|
||||
if text.endswith("."):
|
||||
text += " "
|
||||
else:
|
||||
text += ". "
|
||||
|
||||
voice = {
|
||||
'audio': base64.b64encode(audio_bytes).decode('utf-8'),
|
||||
'text': text,
|
||||
}
|
||||
|
||||
already_exists = False
|
||||
#with RedisCluster.from_url(redis_url) as r:
|
||||
with Redis.from_url(redis_url) as r:
|
||||
redis_key = f'{rds_key_prefix}{audio_key}'
|
||||
resp = r.set(redis_key, json.dumps(voice), nx=not force)
|
||||
if not force and not resp:
|
||||
already_exists = True
|
||||
logger.warning(f"Voice with key {audio_key} already exists in Redis, skipping registration.")
|
||||
|
||||
logger.info(f"Registered voice with key: {audio_key}, text: {text}")
|
||||
|
||||
return audio_key, already_exists
|
||||
|
||||
@app.post("/register_voice")
|
||||
async def register_voice(
|
||||
audio: UploadFile = File(...),
|
||||
text: str = Form(...),
|
||||
audio_name: Optional[str] = Form(None),
|
||||
force: bool = Form(False)
|
||||
):
|
||||
audio_bytes = await audio.read()
|
||||
if audio_name == '':
|
||||
audio_name = None
|
||||
try:
|
||||
audio_key, already_exists = register_voice_to_redis(audio_bytes, text, audio_name, force=force)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register voice: {str(e)}")
|
||||
return JSONResponse(status_code=400, content={"error": str(e)})
|
||||
|
||||
# warmup
|
||||
for _ in generate("流式语音合成,合成测试一", ref_audio_key=audio_key, fast_infer=1):
|
||||
logger.info("Warming up 1")
|
||||
for _ in generate("流式语音合成,合成测试二", ref_audio_key=audio_key, fast_infer=2):
|
||||
logger.info("Warming up 2")
|
||||
|
||||
response = {
|
||||
"status": "success" if not already_exists else "already_exists",
|
||||
"audio_key": audio_key
|
||||
}
|
||||
return JSONResponse(status_code=200, content=response)
|
||||
|
||||
|
||||
def generate(gen_text, text_lang="zh", ref_audio=None, ref_text=None, ref_audio_key=None, fast_infer=0):
|
||||
if ref_audio_key is not None:
|
||||
t1 = time.perf_counter()
|
||||
#with RedisCluster.from_url(redis_url) as r:
|
||||
with Redis.from_url(redis_url) as r:
|
||||
voice_data = r.get(f'{rds_key_prefix}{ref_audio_key}')
|
||||
if not voice_data:
|
||||
raise Exception(f'Voice {ref_audio_key} not found.')
|
||||
voice_data = json.loads(voice_data)
|
||||
t2 = time.perf_counter()
|
||||
logger.info(f"Loaded voice {ref_audio_key} from Redis in {t2 - t1:.3f} seconds")
|
||||
|
||||
if fast_infer >= 2 and 'slice_audio' in voice_data:
|
||||
ref_audio = base64.b64decode(voice_data['slice_audio'])
|
||||
ref_text = voice_data['slice_text']
|
||||
else:
|
||||
ref_audio = base64.b64decode(voice_data['audio'])
|
||||
ref_text = voice_data['text']
|
||||
with open(f"/workspace/wav/{ref_audio_key}.wav", "wb") as f:
|
||||
f.write(ref_audio)
|
||||
ref_audio_path = f"/workspace/wav/{ref_audio_key}.wav"
|
||||
ref_lang = detect_language(ref_text).lower() if ref_text else text_lang
|
||||
|
||||
elif ref_audio is not None:
|
||||
if isinstance(ref_audio, str):
|
||||
ref_audio_path = ref_audio
|
||||
else:
|
||||
audio_key = encode_audio_key(ref_audio)
|
||||
if not os.path.exists(f"/workspace/wav/{audio_key}.wav"):
|
||||
with open(f"/workspace/wav/{audio_key}.wav", "wb") as f:
|
||||
f.write(ref_audio)
|
||||
ref_audio_path = f"/workspace/wav/{audio_key}.wav"
|
||||
ref_lang = detect_language(ref_text).lower() if ref_text else text_lang
|
||||
|
||||
req = {
|
||||
"text": gen_text,
|
||||
"text_lang": text_lang,
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"prompt_text": ref_text,
|
||||
"prompt_lang": ref_lang,
|
||||
"text_split_method": "cut2",
|
||||
"media_type": "wav",
|
||||
"speed_factor": 1.0,
|
||||
"parallel_infer": False,
|
||||
"batch_size": 1,
|
||||
"split_bucket": False,
|
||||
"streaming_mode": True
|
||||
}
|
||||
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
return_fragment = req.get("return_fragment", False)
|
||||
media_type = req.get("media_type", "wav")
|
||||
|
||||
# check_res = check_params(req)
|
||||
# if check_res is not None:
|
||||
# return check_res
|
||||
|
||||
if streaming_mode or return_fragment:
|
||||
req["return_fragment"] = True
|
||||
|
||||
tts_generator = tts_pipeline.run(req)
|
||||
for sr, chunk in tts_generator:
|
||||
yield pack_audio_rate(BytesIO(), chunk, sr, target_rate=16000, media_type=None).getvalue()
|
||||
|
||||
@app.post("/synthesize")
|
||||
async def synthesize(request: Request):
|
||||
data = await request.json()
|
||||
text = data['text']
|
||||
audio_key = data['audio_key']
|
||||
language = data.get('language', 'zh')
|
||||
# fast_infer = data.get('fast_infer', 0)
|
||||
# if fast_infer == True:
|
||||
# fast_infer = 2
|
||||
# else:
|
||||
# fast_infer = int(fast_infer)
|
||||
|
||||
logger.info(f"Synthesizing text: {text}, audio_key: {audio_key}")
|
||||
|
||||
return StreamingResponse(
|
||||
generate(text, text_lang=language, ref_audio_key=audio_key),
|
||||
media_type="audio/wav"
|
||||
)
|
||||
|
||||
@app.post("/synthesize_with_audio")
|
||||
async def synthesize_with_audio(
|
||||
ref_audio: UploadFile = File(...),
|
||||
ref_text: str = Form(...),
|
||||
text: str = Form(...),
|
||||
lang: str = Form("zh"),
|
||||
fast_infer: int = Form(0)
|
||||
):
|
||||
logger.info(f"Synthesizing with audio, text: {text}, ref_text: {ref_text}, fast_infer: {fast_infer}")
|
||||
|
||||
audio_bytes = await ref_audio.read()
|
||||
return StreamingResponse(
|
||||
generate(text, text_lang=lang, ref_audio=audio_bytes, ref_text=ref_text),
|
||||
media_type="audio/wav"
|
||||
)
|
||||
|
||||
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
|
||||
@app.post("/")
|
||||
@app.post("/tts")
|
||||
def predict(ssml: str = Body(...)):
|
||||
try:
|
||||
root = ET.fromstring(ssml)
|
||||
voice_element = root.find(".//voice")
|
||||
if voice_element is not None:
|
||||
text = voice_element.text.strip()
|
||||
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
|
||||
voice_name = voice_element.get("name", "zh").strip()
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
|
||||
except ET.ParseError as e:
|
||||
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
generate(text, language, ref_audio_key=voice_name),
|
||||
media_type=f"audio/wav",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/ready")
|
||||
@app.get("/health")
|
||||
async def ready():
|
||||
return JSONResponse(status_code=200, content={"status": "ok"})
|
||||
|
||||
# @app.get("/health_check")
|
||||
# async def health_check():
|
||||
# try:
|
||||
# a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
|
||||
# b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
|
||||
# c = torch.matmul(a, b)
|
||||
# if c.sum() == 10 * 20 * 10:
|
||||
# return {"status": "ok"}
|
||||
# else:
|
||||
# raise HTTPException(status_code=503)
|
||||
# except Exception as e:
|
||||
# print(f'health_check failed')
|
||||
# raise HTTPException(status_code=503)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
uvicorn.run(app=app, host="0.0.0.0", port=80, workers=1)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
||||
Reference in New Issue
Block a user