Files
2025-08-20 17:05:36 +08:00

450 lines
15 KiB
Python

import os
import sys
import traceback
from typing import Generator
import logging
# Configure the root logger once at import time. The level is read from the
# LOGLEVEL environment variable and falls back to INFO when it is unset.
logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO"),
)
# Module-level logger; note that it is named after this file's path (__file__).
logger = logging.getLogger(__file__)
import torch
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
# torch.manual_seed(0)
# Choose the scaled-dot-product-attention backends process-wide: the flash and
# memory-efficient kernels are switched off and only the math backend is on.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
def custom_conv1d_forward(self, input: Tensor) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.Conv1d``.

    A float16 input that lives on a CUDA device is convolved inside a
    ``torch.amp.autocast`` region requested with ``dtype=torch.float`` and the
    result is converted back with ``.half()``. Any other input is passed
    straight to ``self._conv_forward``.

    Args:
        input: Tensor to convolve.

    Returns:
        The convolution output.
    """
    half_on_cuda = input.device.type == 'cuda' and input.dtype == torch.float16
    if not half_on_cuda:
        return self._conv_forward(input, self.weight, self.bias)
    with torch.amp.autocast(input.device.type, dtype=torch.float):
        convolved = self._conv_forward(input, self.weight, self.bias)
        return convolved.half()


# Install the override globally: every Conv1d module now dispatches to it.
torch.nn.Conv1d.forward = custom_conv1d_forward
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
    """Replacement ``forward`` for ``torch.nn.ConvTranspose1d``.

    A float32 input that lives on a CUDA device is processed inside a
    ``torch.amp.autocast('cuda', dtype=torch.float16)`` region and the result
    is converted back with ``.float()``. Any other input goes through
    ``F.conv_transpose1d`` directly.

    Args:
        input: Tensor to run the transposed convolution on.
        output_size: Optional requested output size, forwarded to
            ``self._output_padding``.

    Returns:
        The transposed-convolution output.

    Raises:
        ValueError: If ``self.padding_mode`` is not ``'zeros'``.
    """
    if self.padding_mode != 'zeros':
        raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

    assert isinstance(self.padding, tuple)
    # One cannot replace List by Tuple or Sequence in "_output_padding" because
    # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
    num_spatial_dims = 1
    output_padding = self._output_padding(
        input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
        num_spatial_dims, self.dilation)  # type: ignore[arg-type]

    def _transposed_conv() -> Tensor:
        # Single definition of the call shared by both precision paths.
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

    float_on_cuda = input.dtype == torch.float and input.device.type == 'cuda'
    if float_on_cuda:
        with torch.amp.autocast('cuda', dtype=torch.float16):
            return _transposed_conv().float()
    return _transposed_conv()


# Install the override globally: every ConvTranspose1d module now dispatches to it.
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
# Append the current working directory and its GPT_SoVITS sub-directory to
# sys.path (presumably so the project-local `tools` / `GPT_SoVITS` imports
# below resolve when the server is started from the repository root).
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
import subprocess
import io
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, Request, Response, Body, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from dataclasses import dataclass
import hashlib
import time
from fast_langdetect import detect_language
import xml.etree.ElementTree as ET
import base64
import json
#from redis.cluster import RedisCluster
from redis import Redis
# Deployment settings, each overridable through an environment variable.
model_dir = os.getenv('MODEL_DIR', '/mnt/models/GPT-SoVITS')  # directory holding the model files
model_name = os.getenv('MODEL_NAME', 's1v3.ckpt')  # checkpoint file name, joined onto model_dir in init()
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")  # URL handed to Redis.from_url
# Prefix prepended to a voice key to form its Redis key.
rds_key_prefix = 'tts:voice:'
# print(sys.path)
i18n = I18nAuto()
# Global TTS pipeline; stays None until init() assigns it.
tts_pipeline = None
# @dataclass
# class RefAudioMeta:
# # audio: bytes
# audio_path: str
# text: str
# lang: str
# # slice_audio: Optional[bytes] = None
# # slice_text: Optional[str] = None
# voice_dict: dict[str, RefAudioMeta] = {}
def init():
    """Build the global TTS pipeline and register the bundled reference voices.

    Assigns the module-level ``tts_pipeline`` and then tries to register two
    reference recordings from ``/workspace/wav`` under the keys ``zh`` and
    ``en``. Registration is best-effort: a failure (missing file, Redis not
    reachable, ...) is logged together with its traceback and start-up
    continues.
    """
    global tts_pipeline
    gsv_config = {
        # "version": "v2ProPlus",
        "custom": {
            "bert_base_path": os.path.join(model_dir, "chinese-roberta-wwm-ext-large"),
            "cnhuhbert_base_path": os.path.join(model_dir, "chinese-hubert-base"),
            "device": "cuda",
            "is_half": False,
            "t2s_weights_path": os.path.join(model_dir, model_name),
            "version": "v2ProPlus",
            "vits_weights_path": os.path.join(model_dir, "v2Pro/s2Gv2ProPlus.pth"),
        }
    }
    tts_config = TTS_Config(gsv_config)
    # tts_config = TTS_Config(config_path)
    logger.info("TTS config: %s", tts_config)
    tts_pipeline = TTS(tts_config)

    # (voice key, reference wav path, transcript of the reference wav)
    default_voices = (
        (
            'zh',
            '/workspace/wav/ningguang.wav',
            "而这条街道,没有半分“不谐”之感,实属难得。",
        ),
        (
            'en',
            '/workspace/wav/bbc_real_en.wav',
            "Hello and welcome to Real Easy English. In this podcast, we have real conversations in easy English to help you learn.",
        ),
    )
    for audio_key, wav_path, ref_text in default_voices:
        try:
            with open(wav_path, 'rb') as f:
                voice_bytes = f.read()
            register_voice_to_redis(voice_bytes, ref_text, audio_key=audio_key)
        except Exception:
            # Best-effort: keep starting up, but record why registration failed.
            logger.warning(
                "Failed to register %s voice, skipping registration.",
                audio_key,
                exc_info=True,
            )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: run ``init()`` before the application starts serving.

    Nothing is done after the ``yield``, i.e. there is no shutdown work.
    """
    init()
    yield


# The ASGI application; ``lifespan`` makes it call init() on start-up.
app = FastAPI(lifespan=lifespan)
### Modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Write ``data`` into ``io_buffer`` as single-channel OGG audio.

    Args:
        io_buffer: Destination buffer.
        data: Audio samples to write.
        rate: Sample rate passed to the sound file writer.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    ogg_writer = sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg")
    with ogg_writer:
        ogg_writer.write(data)
    return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Append the in-memory bytes of ``data`` to ``io_buffer`` with no container.

    Args:
        io_buffer: Destination buffer.
        data: Audio samples; written via ``ndarray.tobytes()``.
        rate: Unused; kept so all ``pack_*`` helpers share one signature.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    payload = data.tobytes()
    io_buffer.write(payload)
    return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` as WAV and return the buffer holding it.

    NOTE: the ``io_buffer`` argument is not written to. A brand-new
    ``BytesIO`` is created and returned, so callers must use the return
    value rather than the buffer they passed in.

    Args:
        io_buffer: Ignored (see note above).
        data: Audio samples to write.
        rate: Sample rate of ``data``.

    Returns:
        A new ``BytesIO`` containing the WAV data.
    """
    wav_buffer = BytesIO()
    sf.write(wav_buffer, data, rate, format="wav")
    return wav_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode ``data`` to AAC (ADTS stream) by piping it through ``ffmpeg``.

    The bytes of ``data`` are fed to ffmpeg's stdin and declared as 16-bit
    signed little-endian mono PCM at ``rate`` Hz; whatever ffmpeg writes to
    stdout is appended to ``io_buffer``. A non-zero ffmpeg exit status is
    logged together with ffmpeg's stderr; the (possibly empty) output is
    still written so callers keep receiving a buffer.

    Args:
        io_buffer: Destination buffer.
        data: Audio samples; its raw bytes are sent to ffmpeg.
        rate: Sample rate declared to ffmpeg.

    Returns:
        The same ``io_buffer`` that was passed in.
    """
    command = [
        "ffmpeg",
        "-f",
        "s16le",  # input is 16-bit signed little-endian PCM
        "-ar",
        str(rate),  # input sample rate
        "-ac",
        "1",  # mono
        "-i",
        "pipe:0",  # read input from stdin
        "-c:a",
        "aac",  # AAC audio encoder
        "-b:a",
        "192k",  # bitrate
        "-vn",  # no video
        "-f",
        "adts",  # output as an AAC ADTS stream
        "pipe:1",  # write output to stdout
    ]
    completed = subprocess.run(
        command,
        input=data.tobytes(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if completed.returncode != 0:
        # Surface encoder failures instead of silently returning empty audio.
        logger.warning(
            "ffmpeg exited with code %s while encoding AAC: %s",
            completed.returncode,
            completed.stderr.decode("utf-8", errors="replace").strip(),
        )
    io_buffer.write(completed.stdout)
    return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Encode ``data`` in the requested container and rewind the result.

    Args:
        io_buffer: Buffer handed to the selected ``pack_*`` helper.
        data: Audio samples.
        rate: Sample rate of ``data``.
        media_type: ``"ogg"``, ``"aac"`` or ``"wav"``; any other value selects
            ``pack_raw``.

    Returns:
        The buffer returned by the selected helper, positioned at offset 0.
    """
    packers = {
        "ogg": pack_ogg,
        "aac": pack_aac,
        "wav": pack_wav,
    }
    packer = packers.get(media_type, pack_raw)
    packed = packer(io_buffer, data, rate)
    packed.seek(0)
    return packed
from scipy.signal import resample
def resample_audio(data: np.ndarray, original_rate: int, target_rate: int):
    """Resample ``data`` from ``original_rate`` to ``target_rate``.

    The output length is ``int(len(data) * target_rate / original_rate)`` and
    the result is cast back to the dtype of the input array.

    Args:
        data: Samples to resample.
        original_rate: Sample rate of ``data``.
        target_rate: Desired sample rate.

    Returns:
        The resampled samples, with the same dtype as ``data``.
    """
    source_dtype = data.dtype
    target_length = int(len(data) * float(target_rate) / original_rate)
    converted = resample(data, target_length)
    return converted.astype(source_dtype)
def _peak_normalize(data: np.ndarray) -> np.ndarray:
    """Scale ``data`` so its largest absolute sample reaches full scale.

    int16 input is scaled to a peak of 32767 and returned as int16; any other
    dtype is divided by its peak. Empty and all-zero inputs are returned
    unchanged, because there is no peak to divide by.
    """
    if data.size == 0:
        return data
    is_int16 = data.dtype == np.int16
    # Work in float32 for int16 so that abs(-32768) cannot overflow.
    samples = data.astype(np.float32) if is_int16 else data
    peak = np.max(np.abs(samples))
    if peak == 0:
        return data
    if is_int16:
        return (samples / peak * 32767).astype(np.int16)
    return data / peak


def pack_audio_rate(io_buffer: BytesIO, data: np.ndarray, original_rate: int, target_rate: int, media_type: str):
    """Optionally resample, peak-normalise and encode one chunk of audio.

    Args:
        io_buffer: Buffer handed to the encoder.
        data: Audio samples.
        original_rate: Sample rate of ``data``.
        target_rate: Desired output rate; a falsy value, or a value equal to
            ``original_rate``, skips resampling.
        media_type: ``"ogg"``, ``"aac"`` or ``"wav"``; anything else
            (including ``None``) produces raw sample bytes.

    Returns:
        The buffer holding the encoded audio, positioned at offset 0.
    """
    if target_rate and target_rate != original_rate:
        data = resample_audio(data, original_rate, target_rate)
        rate = target_rate
    else:
        rate = original_rate
    data = _peak_normalize(data)
    # pack_audio owns the media-type dispatch and the final seek(0).
    return pack_audio(io_buffer, data, rate, media_type)
def encode_audio_key(audio_bytes: bytes) -> str:
    """Derive a 16-character hexadecimal key from the start of an audio blob.

    Only the first 16000 bytes are hashed (MD5), and only the first 16 hex
    digits of the digest are kept, so blobs that share their first 16000
    bytes map to the same key.

    Args:
        audio_bytes: Raw audio file contents.

    Returns:
        The 16-character key.
    """
    head = audio_bytes[:16000]
    digest = hashlib.md5(head).hexdigest()
    return digest[:16]
# Endings that already close a sentence and therefore need nothing appended.
_SENTENCE_TERMINATORS = ("。", "!", "?", "!", "?")


def _ensure_sentence_ending(text: str) -> str:
    """Return ``text`` ending in sentence-final punctuation.

    A trailing ``"."`` gets a space appended, text already ending in ``". "``
    or one of ``_SENTENCE_TERMINATORS`` is kept as is, and anything else gets
    ``". "`` appended. Empty text is returned unchanged.
    """
    if not text or text.endswith(". ") or text.endswith(_SENTENCE_TERMINATORS):
        return text
    if text.endswith("."):
        return text + " "
    return text + ". "


def register_voice_to_redis(audio_bytes, text, audio_key: Optional[str] = None, force: bool = False):
    """Store a reference voice (audio + transcript) in Redis.

    The value is a JSON object with the base64-encoded audio under ``audio``
    and the transcript under ``text``, stored at
    ``rds_key_prefix + audio_key``.

    Args:
        audio_bytes: Raw bytes of the reference audio file.
        text: Transcript of the reference audio.
        audio_key: Key to register under; derived from the audio via
            ``encode_audio_key`` when ``None``.
        force: When true an existing entry is overwritten; otherwise the
            write only happens if the key does not exist yet.

    Returns:
        ``(audio_key, already_exists)`` where ``already_exists`` is true when
        the key was present and the write was skipped.
    """
    if audio_key is None:
        audio_key = encode_audio_key(audio_bytes)
    # Ensure ref_text ends with a proper sentence-ending punctuation.
    # (The previous check used `text.endswith("")`, which is true for every
    # string, so this normalisation never actually ran.)
    text = _ensure_sentence_ending(text)
    voice = {
        'audio': base64.b64encode(audio_bytes).decode('utf-8'),
        'text': text,
    }
    redis_key = f'{rds_key_prefix}{audio_key}'
    # NOTE: a RedisCluster.from_url(redis_url) variant was previously sketched here.
    with Redis.from_url(redis_url) as r:
        # nx=True makes SET a no-op (falsy reply) when the key already exists.
        stored = r.set(redis_key, json.dumps(voice), nx=not force)
    already_exists = not force and not stored
    if already_exists:
        logger.warning(f"Voice with key {audio_key} already exists in Redis, skipping registration.")
    else:
        logger.info(f"Registered voice with key: {audio_key}, text: {text}")
    return audio_key, already_exists
@app.post("/register_voice")
async def register_voice(
    audio: UploadFile = File(...),
    text: str = Form(...),
    audio_name: Optional[str] = Form(None),
    force: bool = Form(False)
):
    """Register an uploaded reference voice and warm the pipeline up with it.

    Args:
        audio: Uploaded reference audio file.
        text: Transcript of the reference audio.
        audio_name: Optional key to register under; an empty string is
            treated as "not given", in which case a key is derived from the
            audio.
        force: Overwrite an existing voice with the same key.

    Returns:
        200 with ``status`` (``success`` / ``already_exists``) and
        ``audio_key``; 400 when the name is unusable or registration fails;
        500 when the warm-up synthesis fails.
    """
    audio_bytes = await audio.read()
    if audio_name == '':
        audio_name = None
    # The key later becomes a file name under /workspace/wav inside generate(),
    # so refuse anything that could escape that directory.
    if audio_name is not None and (
        audio_name in ('.', '..')
        or '\x00' in audio_name
        or os.path.basename(audio_name) != audio_name
    ):
        return JSONResponse(
            status_code=400,
            content={"error": "audio_name must not contain path separators."},
        )
    try:
        audio_key, already_exists = register_voice_to_redis(audio_bytes, text, audio_name, force=force)
    except Exception as e:
        logger.warning(f"Failed to register voice: {str(e)}")
        return JSONResponse(status_code=400, content={"error": str(e)})
    # warmup
    # NOTE(review): generate() runs synchronously here and therefore blocks
    # the event loop for the duration of the warm-up.
    warmups = (
        (1, "流式语音合成,合成测试一"),
        (2, "流式语音合成,合成测试二"),
    )
    try:
        for fast_infer, warmup_text in warmups:
            for _ in generate(warmup_text, ref_audio_key=audio_key, fast_infer=fast_infer):
                logger.info(f"Warming up {fast_infer}")
    except Exception as e:
        # The voice is already stored at this point; report why synthesis failed.
        logger.exception("Warm-up synthesis failed for voice %s", audio_key)
        return JSONResponse(
            status_code=500,
            content={"error": f"Warm-up failed: {e}", "audio_key": audio_key},
        )
    response = {
        "status": "success" if not already_exists else "already_exists",
        "audio_key": audio_key
    }
    return JSONResponse(status_code=200, content=response)
def generate(gen_text, text_lang="zh", ref_audio=None, ref_text=None, ref_audio_key=None, fast_infer=0):
    """Stream synthesised speech for ``gen_text`` as chunks of sample bytes.

    The reference voice comes either from Redis (``ref_audio_key``) or from
    the caller (``ref_audio``). Because this is a generator, nothing runs
    (and no error is raised) until the first chunk is requested.

    Args:
        gen_text: Text to synthesise.
        text_lang: Language of ``gen_text``; also used as the prompt language
            when no reference transcript is available.
        ref_audio: Reference audio, either a file path (``str``) or the raw
            file bytes. Ignored when ``ref_audio_key`` is given.
        ref_text: Transcript of ``ref_audio``. Replaced by the stored
            transcript when ``ref_audio_key`` is given.
        ref_audio_key: Key of a voice previously stored in Redis.
        fast_infer: With a value >= 2 the stored ``slice_audio`` /
            ``slice_text`` pair is preferred when the voice record has one.

    Yields:
        ``bytes``: each pipeline chunk, resampled to 16000 Hz when needed,
        peak-normalised and emitted as raw sample bytes (no container).

    Raises:
        ValueError: If the voice key is not usable as a file name, or if
            neither ``ref_audio_key`` nor ``ref_audio`` is provided.
        Exception: If the voice key is not found in Redis.
    """
    wav_dir = "/workspace/wav"
    if ref_audio_key is not None:
        # The key becomes a file name below; refuse anything that would
        # escape wav_dir.
        key_name = str(ref_audio_key)
        if key_name in (".", "..") or "\x00" in key_name or os.path.basename(key_name) != key_name:
            raise ValueError(f"Invalid voice key: {ref_audio_key!r}")
        t1 = time.perf_counter()
        # NOTE: a RedisCluster.from_url(redis_url) variant was previously sketched here.
        with Redis.from_url(redis_url) as r:
            voice_data = r.get(f'{rds_key_prefix}{ref_audio_key}')
        if not voice_data:
            raise Exception(f'Voice {ref_audio_key} not found.')
        voice_data = json.loads(voice_data)
        t2 = time.perf_counter()
        logger.info(f"Loaded voice {ref_audio_key} from Redis in {t2 - t1:.3f} seconds")
        if fast_infer >= 2 and 'slice_audio' in voice_data:
            ref_audio = base64.b64decode(voice_data['slice_audio'])
            ref_text = voice_data['slice_text']
        else:
            ref_audio = base64.b64decode(voice_data['audio'])
            ref_text = voice_data['text']
        # The pipeline takes a path, so materialise the audio on disk.
        ref_audio_path = f"{wav_dir}/{ref_audio_key}.wav"
        with open(ref_audio_path, "wb") as f:
            f.write(ref_audio)
    elif ref_audio is not None:
        if isinstance(ref_audio, str):
            ref_audio_path = ref_audio
        else:
            audio_key = encode_audio_key(ref_audio)
            ref_audio_path = f"{wav_dir}/{audio_key}.wav"
            if not os.path.exists(ref_audio_path):
                with open(ref_audio_path, "wb") as f:
                    f.write(ref_audio)
    else:
        # Previously this fell through to an UnboundLocalError on ref_audio_path.
        raise ValueError("Either ref_audio_key or ref_audio must be provided.")
    ref_lang = detect_language(ref_text).lower() if ref_text else text_lang

    req = {
        "text": gen_text,
        "text_lang": text_lang,
        "ref_audio_path": ref_audio_path,
        "prompt_text": ref_text,
        "prompt_lang": ref_lang,
        "text_split_method": "cut2",
        "media_type": "wav",
        "speed_factor": 1.0,
        "parallel_infer": False,
        "batch_size": 1,
        "split_bucket": False,
        "streaming_mode": True,
        # Streaming is always on here, so fragments are always requested.
        "return_fragment": True,
    }
    # check_res = check_params(req)
    # if check_res is not None:
    #     return check_res
    tts_generator = tts_pipeline.run(req)
    for sr, chunk in tts_generator:
        yield pack_audio_rate(BytesIO(), chunk, sr, target_rate=16000, media_type=None).getvalue()
@app.post("/synthesize")
async def synthesize(request: Request):
    """Synthesise speech for a JSON request using a registered voice.

    The JSON body must be an object with ``text`` and ``audio_key``;
    ``language`` is optional and defaults to ``"zh"``.

    Returns:
        A streaming response fed by ``generate``.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not an object,
            or lacks ``text`` / ``audio_key``.
    """
    try:
        data = await request.json()
    except ValueError as e:  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
    missing = [field for field in ('text', 'audio_key') if field not in data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required field(s): {', '.join(missing)}",
        )
    text = data['text']
    audio_key = data['audio_key']
    language = data.get('language', 'zh')
    logger.info(f"Synthesizing text: {text}, audio_key: {audio_key}")
    # NOTE(review): generate() emits header-less sample bytes although the
    # response is labelled audio/wav; confirm this is what clients expect.
    return StreamingResponse(
        generate(text, text_lang=language, ref_audio_key=audio_key),
        media_type="audio/wav"
    )
@app.post("/synthesize_with_audio")
async def synthesize_with_audio(
    ref_audio: UploadFile = File(...),
    ref_text: str = Form(...),
    text: str = Form(...),
    lang: str = Form("zh"),
    fast_infer: int = Form(0)
):
    """Synthesise ``text`` using a reference voice uploaded with the request.

    Args:
        ref_audio: Uploaded reference audio file.
        ref_text: Transcript of the reference audio.
        text: Text to synthesise.
        lang: Language of ``text``.
        fast_infer: Accepted and logged, but not forwarded to ``generate``.

    Returns:
        A streaming response fed by ``generate``.
    """
    logger.info(f"Synthesizing with audio, text: {text}, ref_text: {ref_text}, fast_infer: {fast_infer}")
    audio_bytes = await ref_audio.read()
    audio_stream = generate(text, text_lang=lang, ref_audio=audio_bytes, ref_text=ref_text)
    return StreamingResponse(audio_stream, media_type="audio/wav")
# Clark-notation prefix of the reserved XML namespace; used to read xml:lang.
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"


@app.post("/")
@app.post("/tts")
def predict(ssml: str = Body(...)):
    """Synthesise speech described by an SSML document.

    The first ``<voice>`` element below the root supplies the text (all text
    nested inside it), the language (``xml:lang``, default ``"zh"``) and the
    registered voice key (``name``, default ``"zh"``).

    Args:
        ssml: The SSML document as a string.

    Returns:
        A streaming response fed by ``generate``, or a 400 JSON response when
        the document cannot be parsed, has no ``<voice>`` element, or the
        element contains no text.
    """
    try:
        root = ET.fromstring(ssml)
    except ET.ParseError as e:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})

    # "{*}voice" matches <voice> in any namespace or none, so documents that
    # declare the SSML default namespace are accepted as well.
    voice_element = root.find(".//{*}voice")
    if voice_element is None:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})

    # itertext() also collects text inside nested elements and copes with an
    # empty element, where `.text` would be None.
    text = "".join(voice_element.itertext()).strip()
    if not text:
        return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element has no text."})
    language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
    voice_name = voice_element.get("name", "zh").strip()

    return StreamingResponse(
        generate(text, language, ref_audio_key=voice_name),
        media_type="audio/wav",
    )
@app.get("/ready")
@app.get("/health")
async def ready():
    """Readiness / liveness probe: always answers 200 with ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return JSONResponse(content=payload, status_code=200)
@app.get("/health_check")
async def health_check():
    """Probe that runs a small matrix multiplication on the CUDA device.

    Multiplies a 10x20 matrix of ones by a 20x10 matrix of ones on ``cuda``
    and checks that the elements of the product sum to 10 * 20 * 10.

    Returns:
        ``{"status": "ok"}`` when the result is as expected.

    Raises:
        HTTPException: 503 when the computation raises or yields a wrong sum.
    """
    expected_total = 10 * 20 * 10
    try:
        a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
        b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
        total = torch.matmul(a, b).sum().item()
    except Exception as exc:
        # Keep the traceback in the logs; the client only sees the 503.
        logger.exception("health_check failed")
        raise HTTPException(status_code=503) from exc
    if total != expected_total:
        logger.error("health_check failed: expected sum %s, got %s", expected_total, total)
        raise HTTPException(status_code=503)
    return {"status": "ok"}
# Script entry point: serve the app with a single uvicorn worker on port 80.
if __name__ == "__main__":
    try:
        uvicorn.run(app=app, host="0.0.0.0", port=80, workers=1)
    except Exception:
        # Print the failure, then send SIGTERM to this very process before
        # exiting.
        traceback.print_exc()
        own_pid = os.getpid()
        os.kill(own_pid, signal.SIGTERM)
        exit(0)