From b868526d94f7729c495b6582ce71226ebba008d9 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Sun, 20 Apr 2025 20:49:27 -0400 Subject: [PATCH] Fix one more issue reported by torchfix (#4859) --- python/sglang/srt/models/minicpmo.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/minicpmo.py b/python/sglang/srt/models/minicpmo.py index ad89dcea2..3a0edf52b 100644 --- a/python/sglang/srt/models/minicpmo.py +++ b/python/sglang/srt/models/minicpmo.py @@ -25,7 +25,7 @@ import torch.nn.functional as F import torch.nn.utils.parametrize as P import torch.types from torch import nn -from torch.nn.utils import weight_norm +from torch.nn.utils import parametrizations from tqdm import tqdm from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN @@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel): self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) self.head_code = nn.ModuleList( [ - weight_norm( + parametrizations.weight_norm( nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), name="weight", ) @@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel): # the checkpoint. Skip them. continue - # adapt to parametrization + # For weight_norm parametrization, handle both old and new formats if self.config.init_tts and "tts" in name: - name = name.replace(".parametrizations", "") - name = name.replace(".weight.original0", ".weight_g") - name = name.replace(".weight.original1", ".weight_v") + # Handle loading from older checkpoints with weight_g/weight_v format + if ".weight_g" in name or ".weight_v" in name: + name = name.replace( + ".weight_g", ".parametrizations.weight.original0" + ) + name = name.replace( + ".weight_v", ".parametrizations.weight.original1" + ) + elif ".weight" in name and name not in params_dict: + param_name = name.replace( + ".weight", ".parametrizations.weight.original0" + ) + if param_name in params_dict: + name = param_name # adapt to VisionAttention if "vpm" in name: