Fix one more issue reported by torchfix (#4859)

This commit is contained in:
Brayden Zhong
2025-04-20 20:49:27 -04:00
committed by GitHub
parent 502524e2da
commit b868526d94

View File

@@ -25,7 +25,7 @@ import torch.nn.functional as F
import torch.nn.utils.parametrize as P import torch.nn.utils.parametrize as P
import torch.types import torch.types
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils import parametrizations
from tqdm import tqdm from tqdm import tqdm
from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
@@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel):
self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
self.head_code = nn.ModuleList( self.head_code = nn.ModuleList(
[ [
weight_norm( parametrizations.weight_norm(
nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
name="weight", name="weight",
) )
@@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel):
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
# adapt to parametrization # For weight_norm parametrization, handle both old and new formats
if self.config.init_tts and "tts" in name: if self.config.init_tts and "tts" in name:
name = name.replace(".parametrizations", "") # Handle loading from older checkpoints with weight_g/weight_v format
name = name.replace(".weight.original0", ".weight_g") if ".weight_g" in name or ".weight_v" in name:
name = name.replace(".weight.original1", ".weight_v") name = name.replace(
".weight_g", ".parametrizations.weight.original0"
)
name = name.replace(
".weight_v", ".parametrizations.weight.original1"
)
elif ".weight" in name and name not in params_dict:
param_name = name.replace(
".weight", ".parametrizations.weight.original0"
)
if param_name in params_dict:
name = param_name
# adapt to VisionAttention # adapt to VisionAttention
if "vpm" in name: if "vpm" in name: