update kokoro fix

This commit is contained in:
2025-09-12 15:42:17 +08:00
parent 4eba0d1486
commit 18aa9001d8
3 changed files with 48 additions and 47 deletions

View File

@@ -118,7 +118,7 @@ else:
| 模型名称 | 模型类型 | 适配状态 | 沐曦 MetaX C500运行时间/s | Nvidia A100运行时间/s |
| ---------- | ---------------------- | -------- | ----------------- | --------------------- |
| kokoro | StyleTTS 2, ISTFTNet | 成功 | 202.6 | 5.4 |
| kokoro | StyleTTS 2, ISTFTNet | 成功 | 4.3 | 5.4 |
| f5-TTS | DiT, ConvNeXt V2 | 成功 | 7.1 | 5.4 |
| gpt-sovits | VITS | 成功 | 24.4 | 20.5 |
| matcha | OT-CFM, Transformer | 成功 | 2.9 | 3.2 |

View File

@@ -42,5 +42,4 @@ curl --request POST "http://localhost:8080/tts" \
---
- Patch 能运行,但是生成的音频会有噪声
- Patch: 将 decoder/istftnet 固定为 CPU FP32,禁用 AMP/TF32,修复 GPU“打字机”噪声
- Patch: 修复 torch.istft 复数运算出错问题,修复 GPU“打字机”噪声

View File

@@ -16,6 +16,10 @@ from scipy.signal import resample
import torch
from torch import Tensor
from torch.nn import functional as F
import torch.nn as nn
import types
from typing import Optional, List
import re
from dataclasses import dataclass
@@ -95,56 +99,55 @@ def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
# ================== decoder/istftnet 补丁 ==================
def _to_cpu_fp32(obj):
if torch.is_tensor(obj):
return obj.detach().to("cpu", dtype=torch.float32)
if isinstance(obj, (list, tuple)):
return type(obj)(_to_cpu_fp32(x) for x in obj)
if isinstance(obj, dict):
return {k: _to_cpu_fp32(v) for k, v in obj.items()}
return obj
# NOTE(review): this span is a rendered git diff with the +/- markers AND
# indentation stripped.  It interleaves two versions line-by-line:
#   * OLD (removed): def patch_decoder(...) — forced the decoder to CPU/FP32
#     and wrapped decoder.forward to disable AMP and move args per device.
#   * NEW (added): def _inverse_no_complex(self, magnitude, phase) — rebuilds
#     the complex spectrogram from magnitude/phase and runs torch.istft on
#     CPU, avoiding complex ops on the GPU (the "typewriter noise" fix).
# Do not treat this block as runnable source; recover each side from the
# original commit before editing.
import kokoro.istftnet as _ist
# OLD side: entry point of the removed decoder patch.
def patch_decoder(model, device: str):
decoder = getattr(model, "decoder", None)
if decoder is None:
raise RuntimeError("未找到 model.decoder请 print(model) 确认实际模块名。")
# NEW side: replacement for TorchSTFT.inverse (assigned below at module level).
def _inverse_no_complex(self, magnitude, phase):
device = magnitude.device
dtype = magnitude.dtype
# OLD side: force the decoder itself into eval/FP32 and strip weight norm.
decoder.eval().to(device).float()
try: torch.nn.utils.remove_weight_norm(decoder)
except Exception: pass
# NEW side: presumably an unused leftover window on the source device — TODO confirm.
win_dev = torch.hann_window(self.win_length, device=device, dtype=dtype)
# OLD side: freeze parameters and coerce params/buffers to float32.
for p in decoder.parameters():
p.requires_grad = False
if p.dtype != torch.float32: p.data = p.data.float()
for n, b in decoder.named_buffers():
if b.dtype != torch.float32: setattr(decoder, n, b.float())
# NEW side: rebuild real/imag parts from magnitude and phase.
real = magnitude * torch.cos(phase)
imag = magnitude * torch.sin(phase)
spec_ri = torch.stack([real, imag], dim=-1).contiguous() # (..., 2)
orig_forward = decoder.forward
# NEW side: assemble the complex spectrogram on CPU so torch.istft's complex
# math never runs on the (apparently faulty) GPU backend.
real_cpu = real.to("cpu")
imag_cpu = imag.to("cpu")
spec_complex_cpu = torch.complex(real_cpu, imag_cpu) # (..,) complex tensor
win_cpu = torch.hann_window(self.win_length, device="cpu", dtype=dtype)
# OLD side: wrapper that pinned decoder inputs/outputs to one device in FP32.
def forward_patched(*args, **kwargs):
# key: while the decoder runs, force FP32 and keep inputs on the decoder's device
with torch.amp.autocast('cuda', enabled=False):
if device == "cpu":
args_ = _to_cpu_fp32(args); kwargs_ = _to_cpu_fp32(kwargs)
out = orig_forward(*args_, **kwargs_)
return _to_cpu_fp32(out)
else: # device == "cuda"
def to_gpu_fp32(x):
if torch.is_tensor(x):
return x.detach().to("cuda", dtype=torch.float32)
if isinstance(x, (list, tuple)):
return type(x)(to_gpu_fp32(t) for t in x)
if isinstance(x, dict):
return {k: to_gpu_fp32(v) for k, v in x.items()}
return x
args_ = to_gpu_fp32(args); kwargs_ = to_gpu_fp32(kwargs)
out = orig_forward(*args_, **kwargs_)
if torch.is_tensor(out):
return out.detach().to(device, dtype=torch.float32)
return out
# NEW side: CPU inverse STFT, then move the waveform back to the caller's device.
wav_cpu = torch.istft(
spec_complex_cpu,
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=win_cpu,
center=True,
normalized=False,
onesided=True,
)
return wav_cpu.to(device).unsqueeze(-2)
decoder.forward = forward_patched
def _transform_no_complex(self, input_data):
z = torch.stft(
input_data,
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window.to(input_data.device, dtype=input_data.dtype),
return_complex=False,
center=True,
normalized=False,
)
real = z[..., 0]
imag = z[..., 1]
magnitude = torch.sqrt(real * real + imag * imag)
phase = torch.atan2(imag, real)
return magnitude, phase
# Replace Kokoro's TorchSTFT inverse/transform with the real-valued
# implementations above, so no complex-tensor ops run on the GPU backend
# (works around the istft complex-arithmetic bug / "typewriter" noise).
_ist.TorchSTFT.inverse = _inverse_no_complex
_ist.TorchSTFT.transform = _transform_no_complex
def init():
@@ -153,7 +156,6 @@ def init():
global pipeline_dict
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval()
patch_decoder(model, "cpu")
en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)