Fix one more issue reported by torchfix (#4859)
This commit is contained in:
@@ -25,7 +25,7 @@ import torch.nn.functional as F
|
|||||||
import torch.nn.utils.parametrize as P
|
import torch.nn.utils.parametrize as P
|
||||||
import torch.types
|
import torch.types
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils import parametrizations
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
|
from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
@@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel):
|
|||||||
self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
|
self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
|
||||||
self.head_code = nn.ModuleList(
|
self.head_code = nn.ModuleList(
|
||||||
[
|
[
|
||||||
weight_norm(
|
parametrizations.weight_norm(
|
||||||
nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
|
nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
|
||||||
name="weight",
|
name="weight",
|
||||||
)
|
)
|
||||||
@@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel):
|
|||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# adapt to parametrization
|
# For weight_norm parametrization, handle both old and new formats
|
||||||
if self.config.init_tts and "tts" in name:
|
if self.config.init_tts and "tts" in name:
|
||||||
name = name.replace(".parametrizations", "")
|
# Handle loading from older checkpoints with weight_g/weight_v format
|
||||||
name = name.replace(".weight.original0", ".weight_g")
|
if ".weight_g" in name or ".weight_v" in name:
|
||||||
name = name.replace(".weight.original1", ".weight_v")
|
name = name.replace(
|
||||||
|
".weight_g", ".parametrizations.weight.original0"
|
||||||
|
)
|
||||||
|
name = name.replace(
|
||||||
|
".weight_v", ".parametrizations.weight.original1"
|
||||||
|
)
|
||||||
|
elif ".weight" in name and name not in params_dict:
|
||||||
|
param_name = name.replace(
|
||||||
|
".weight", ".parametrizations.weight.original0"
|
||||||
|
)
|
||||||
|
if param_name in params_dict:
|
||||||
|
name = param_name
|
||||||
|
|
||||||
# adapt to VisionAttention
|
# adapt to VisionAttention
|
||||||
if "vpm" in name:
|
if "vpm" in name:
|
||||||
|
|||||||
Reference in New Issue
Block a user