update kokoro fix
This commit is contained in:
@@ -118,7 +118,7 @@ else:
|
||||
|
||||
| 模型名称 | 模型类型 | 适配状态 | 沐曦 MetaX C500运行时间/s | Nvidia A100运行时间/s |
|
||||
| ---------- | ---------------------- | -------- | ----------------- | --------------------- |
|
||||
| kokoro | StyleTTS 2, ISTFTNet | 成功 | 202.6 | 5.4 |
|
||||
| kokoro | StyleTTS 2, ISTFTNet | 成功 | 4.3 | 5.4 |
|
||||
| f5-TTS | DiT, ConvNeXt V2 | 成功 | 7.1 | 5.4 |
|
||||
| gpt-sovits | VITS | 成功 | 24.4 | 20.5 |
|
||||
| matcha | OT-CFM, Transformer | 成功 | 2.9 | 3.2 |
|
||||
|
||||
@@ -42,5 +42,4 @@ curl --request POST "http://localhost:8080/tts" \
|
||||
|
||||
---
|
||||
|
||||
- 无 Patch 能运行,但是生成的音频会有噪声
|
||||
- Patch: 将 decoder/istftnet 固定为 CPU FP32(禁用 AMP/TF32),修复 GPU “打字机”噪声
|
||||
- Patch: 修复torch.istft 复数运算出错问题 ,修复 GPU “打字机”噪声
|
||||
|
||||
@@ -16,6 +16,10 @@ from scipy.signal import resample
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
import torch.nn as nn
|
||||
import types
|
||||
|
||||
from typing import Optional, List
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
@@ -95,56 +99,55 @@ def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
|
||||
|
||||
# ================== decoder/istftnet 补丁 ==================
|
||||
|
||||
def _to_cpu_fp32(obj):
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().to("cpu", dtype=torch.float32)
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(_to_cpu_fp32(x) for x in obj)
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_fp32(v) for k, v in obj.items()}
|
||||
return obj
|
||||
import kokoro.istftnet as _ist
|
||||
|
||||
def patch_decoder(model, device: str):
|
||||
decoder = getattr(model, "decoder", None)
|
||||
if decoder is None:
|
||||
raise RuntimeError("未找到 model.decoder,请 print(model) 确认实际模块名。")
|
||||
def _inverse_no_complex(self, magnitude, phase):
|
||||
device = magnitude.device
|
||||
dtype = magnitude.dtype
|
||||
|
||||
decoder.eval().to(device).float()
|
||||
try: torch.nn.utils.remove_weight_norm(decoder)
|
||||
except Exception: pass
|
||||
win_dev = torch.hann_window(self.win_length, device=device, dtype=dtype)
|
||||
|
||||
for p in decoder.parameters():
|
||||
p.requires_grad = False
|
||||
if p.dtype != torch.float32: p.data = p.data.float()
|
||||
for n, b in decoder.named_buffers():
|
||||
if b.dtype != torch.float32: setattr(decoder, n, b.float())
|
||||
real = magnitude * torch.cos(phase)
|
||||
imag = magnitude * torch.sin(phase)
|
||||
spec_ri = torch.stack([real, imag], dim=-1).contiguous() # (..., 2)
|
||||
|
||||
orig_forward = decoder.forward
|
||||
real_cpu = real.to("cpu")
|
||||
imag_cpu = imag.to("cpu")
|
||||
spec_complex_cpu = torch.complex(real_cpu, imag_cpu) # (..,) 复数张量
|
||||
win_cpu = torch.hann_window(self.win_length, device="cpu", dtype=dtype)
|
||||
|
||||
def forward_patched(*args, **kwargs):
|
||||
# 关键:decoder 运行时,强制 FP32、并保证输入与 decoder 在同一设备
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
if device == "cpu":
|
||||
args_ = _to_cpu_fp32(args); kwargs_ = _to_cpu_fp32(kwargs)
|
||||
out = orig_forward(*args_, **kwargs_)
|
||||
return _to_cpu_fp32(out)
|
||||
else: # device == "cuda"
|
||||
def to_gpu_fp32(x):
|
||||
if torch.is_tensor(x):
|
||||
return x.detach().to("cuda", dtype=torch.float32)
|
||||
if isinstance(x, (list, tuple)):
|
||||
return type(x)(to_gpu_fp32(t) for t in x)
|
||||
if isinstance(x, dict):
|
||||
return {k: to_gpu_fp32(v) for k, v in x.items()}
|
||||
return x
|
||||
args_ = to_gpu_fp32(args); kwargs_ = to_gpu_fp32(kwargs)
|
||||
out = orig_forward(*args_, **kwargs_)
|
||||
if torch.is_tensor(out):
|
||||
return out.detach().to(device, dtype=torch.float32)
|
||||
return out
|
||||
wav_cpu = torch.istft(
|
||||
spec_complex_cpu,
|
||||
n_fft=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=win_cpu,
|
||||
center=True,
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
)
|
||||
return wav_cpu.to(device).unsqueeze(-2)
|
||||
|
||||
decoder.forward = forward_patched
|
||||
def _transform_no_complex(self, input_data):
|
||||
z = torch.stft(
|
||||
input_data,
|
||||
n_fft=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=self.window.to(input_data.device, dtype=input_data.dtype),
|
||||
return_complex=False,
|
||||
center=True,
|
||||
normalized=False,
|
||||
)
|
||||
real = z[..., 0]
|
||||
imag = z[..., 1]
|
||||
magnitude = torch.sqrt(real * real + imag * imag)
|
||||
phase = torch.atan2(imag, real)
|
||||
return magnitude, phase
|
||||
|
||||
# 替换 Kokoro 的 STFT.inverse 实现
|
||||
_ist.TorchSTFT.inverse = _inverse_no_complex
|
||||
_ist.TorchSTFT.transform = _transform_no_complex
|
||||
|
||||
|
||||
def init():
|
||||
@@ -153,7 +156,6 @@ def init():
|
||||
global pipeline_dict
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval()
|
||||
patch_decoder(model, "cpu")
|
||||
|
||||
en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
|
||||
en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)
|
||||
|
||||
Reference in New Issue
Block a user