diff --git a/python/sglang/srt/models/minicpmo.py b/python/sglang/srt/models/minicpmo.py index ad89dcea2..3a0edf52b 100644 --- a/python/sglang/srt/models/minicpmo.py +++ b/python/sglang/srt/models/minicpmo.py @@ -25,7 +25,7 @@ import torch.nn.functional as F import torch.nn.utils.parametrize as P import torch.types from torch import nn -from torch.nn.utils import weight_norm +from torch.nn.utils import parametrizations from tqdm import tqdm from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN @@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel): self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) self.head_code = nn.ModuleList( [ - weight_norm( + parametrizations.weight_norm( nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), name="weight", ) @@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel): # the checkpoint. Skip them. continue - # adapt to parametrization + # For weight_norm parametrization, handle both old and new formats if self.config.init_tts and "tts" in name: - name = name.replace(".parametrizations", "") - name = name.replace(".weight.original0", ".weight_g") - name = name.replace(".weight.original1", ".weight_v") + # Handle loading from older checkpoints with weight_g/weight_v format + if ".weight_g" in name or ".weight_v" in name: + name = name.replace( + ".weight_g", ".parametrizations.weight.original0" + ) + name = name.replace( + ".weight_v", ".parametrizations.weight.original1" + ) + elif ".weight" in name and name not in params_dict: + param_name = name.replace( + ".weight", ".parametrizations.weight.original0" + ) + if param_name in params_dict: + name = param_name # adapt to VisionAttention if "vpm" in name: