diff --git a/README.md b/README.md
index 1f939eb..cef408a 100644
--- a/README.md
+++ b/README.md
@@ -118,7 +118,7 @@ else:
 
 | 模型名称 | 模型类型 | 适配状态 | 沐曦 MetaX C500运行时间/s | Nvidia A100运行时间/s |
 | ---------- | ---------------------- | -------- | ----------------- | --------------------- |
-| kokoro | StyleTTS 2, ISTFTNet | 成功 | 202.6 | 5.4 |
+| kokoro | StyleTTS 2, ISTFTNet | 成功 | 4.3 | 5.4 |
 | f5-TTS | DiT, ConvNeXt V2 | 成功 | 7.1 | 5.4 |
 | gpt-sovits | VITS | 成功 | 24.4 | 20.5 |
 | matcha | OT-CFM, Transformer | 成功 | 2.9 | 3.2 |
diff --git a/metaX-C500-kokoro/README.md b/metaX-C500-kokoro/README.md
index 4afdda2..5a87470 100644
--- a/metaX-C500-kokoro/README.md
+++ b/metaX-C500-kokoro/README.md
@@ -42,5 +42,4 @@ curl --request POST "http://localhost:8080/tts" \
 
 ---
 
-- 无 Patch 能运行,但是生成的音频会有噪声
-- Patch: 将 decoder/istftnet 固定为 CPU FP32(禁用 AMP/TF32),修复 GPU “打字机”噪声
+- Patch: 修复 torch.istft 复数运算出错问题,修复 GPU “打字机”噪声
diff --git a/metaX-C500-kokoro/kokoro_server.py b/metaX-C500-kokoro/kokoro_server.py
index d1e980e..e3a4e1b 100644
--- a/metaX-C500-kokoro/kokoro_server.py
+++ b/metaX-C500-kokoro/kokoro_server.py
@@ -16,6 +16,10 @@ from scipy.signal import resample
 import torch
 from torch import Tensor
 from torch.nn import functional as F
+
+import torch.nn as nn
+import types
+
 from typing import Optional, List
 import re
 from dataclasses import dataclass
@@ -95,56 +99,55 @@ def audio_postprocess(data: np.ndarray, original_rate: int, target_rate: int):
 
 
 # ================== decoder/istftnet 补丁 ==================
-def _to_cpu_fp32(obj):
-    if torch.is_tensor(obj):
-        return obj.detach().to("cpu", dtype=torch.float32)
-    if isinstance(obj, (list, tuple)):
-        return type(obj)(_to_cpu_fp32(x) for x in obj)
-    if isinstance(obj, dict):
-        return {k: _to_cpu_fp32(v) for k, v in obj.items()}
-    return obj
+import kokoro.istftnet as _ist
 
-def patch_decoder(model, device: str):
-    decoder = getattr(model, "decoder", None)
-    if decoder is None:
-        raise RuntimeError("未找到 model.decoder,请 print(model) 确认实际模块名。")
+def _inverse_no_complex(self, magnitude, phase):
+    device = magnitude.device
+    dtype = magnitude.dtype
 
-    decoder.eval().to(device).float()
-    try: torch.nn.utils.remove_weight_norm(decoder)
-    except Exception: pass
+    win_dev = torch.hann_window(self.win_length, device=device, dtype=dtype)
 
-    for p in decoder.parameters():
-        p.requires_grad = False
-        if p.dtype != torch.float32: p.data = p.data.float()
-    for n, b in decoder.named_buffers():
-        if b.dtype != torch.float32: setattr(decoder, n, b.float())
+    real = magnitude * torch.cos(phase)
+    imag = magnitude * torch.sin(phase)
+    spec_ri = torch.stack([real, imag], dim=-1).contiguous()  # (..., 2)
 
-    orig_forward = decoder.forward
+    real_cpu = real.to("cpu")
+    imag_cpu = imag.to("cpu")
+    spec_complex_cpu = torch.complex(real_cpu, imag_cpu)  # (..,) 复数张量
+    win_cpu = torch.hann_window(self.win_length, device="cpu", dtype=dtype)
 
-    def forward_patched(*args, **kwargs):
-        # 关键:decoder 运行时,强制 FP32、并保证输入与 decoder 在同一设备
-        with torch.amp.autocast('cuda', enabled=False):
-            if device == "cpu":
-                args_ = _to_cpu_fp32(args); kwargs_ = _to_cpu_fp32(kwargs)
-                out = orig_forward(*args_, **kwargs_)
-                return _to_cpu_fp32(out)
-            else:  # device == "cuda"
-                def to_gpu_fp32(x):
-                    if torch.is_tensor(x):
-                        return x.detach().to("cuda", dtype=torch.float32)
-                    if isinstance(x, (list, tuple)):
-                        return type(x)(to_gpu_fp32(t) for t in x)
-                    if isinstance(x, dict):
-                        return {k: to_gpu_fp32(v) for k, v in x.items()}
-                    return x
-                args_ = to_gpu_fp32(args); kwargs_ = to_gpu_fp32(kwargs)
-                out = orig_forward(*args_, **kwargs_)
-                if torch.is_tensor(out):
-                    return out.detach().to(device, dtype=torch.float32)
-                return out
+    wav_cpu = torch.istft(
+        spec_complex_cpu,
+        n_fft=self.filter_length,
+        hop_length=self.hop_length,
+        win_length=self.win_length,
+        window=win_cpu,
+        center=True,
+        normalized=False,
+        onesided=True,
+    )
+    return wav_cpu.to(device).unsqueeze(-2)
 
-    decoder.forward = forward_patched
+def _transform_no_complex(self, input_data):
+    z = torch.stft(
+        input_data,
+        n_fft=self.filter_length,
+        hop_length=self.hop_length,
+        win_length=self.win_length,
+        window=self.window.to(input_data.device, dtype=input_data.dtype),
+        return_complex=False,
+        center=True,
+        normalized=False,
+    )
+    real = z[..., 0]
+    imag = z[..., 1]
+    magnitude = torch.sqrt(real * real + imag * imag)
+    phase = torch.atan2(imag, real)
+    return magnitude, phase
 
+# 替换 Kokoro 的 STFT.inverse 实现
+_ist.TorchSTFT.inverse = _inverse_no_complex
+_ist.TorchSTFT.transform = _transform_no_complex
 
 
 def init():
@@ -153,7 +156,6 @@ def init():
 
     global pipeline_dict
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = KModel(repo_id=repo_id, model=os.path.join(model_dir, model_name), config=os.path.join(model_dir, 'config.json')).to(device).eval()
-    patch_decoder(model, "cpu")
    en_empty_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=False)
    en_pipeline = KPipeline(lang_code='a', repo_id=repo_id, model=model)