Files

657 lines
23 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
2026-01-19 10:38:50 +08:00
2026-01-09 13:34:11 +08:00
import math
2026-01-19 10:38:50 +08:00
from collections.abc import Iterable
from itertools import islice
from typing import Any
2026-01-09 13:34:11 +08:00
import torch
from torch import nn
2026-01-19 10:38:50 +08:00
from transformers import PretrainedConfig
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.layernorm import RMSNorm
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.logits_processor import LogitsProcessor
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.quantization import QuantizationConfig
2026-01-09 13:34:11 +08:00
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
2026-01-19 10:38:50 +08:00
ParallelLMHead,
VocabParallelEmbedding,
)
2026-01-09 13:34:11 +08:00
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
2026-01-19 10:38:50 +08:00
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
2026-01-09 13:34:11 +08:00
class MiniCPMMoE(nn.Module):
    """Mixture-of-experts block whose experts are sharded over TP ranks.

    Each expert's weights are split along the intermediate dimension so
    that every tensor-parallel rank holds one slice of every expert. The
    forward pass routes tokens through a replicated gate, runs the fused
    MoE kernel on the local slices and, when more than one rank is
    involved, all-reduces the partial results.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        tp_size: int | None = None,
    ):
        """Allocate the router and the (uninitialised) expert weights.

        Args:
            num_experts: Total number of experts.
            top_k: Number of experts selected per token.
            hidden_size: Model hidden dimension.
            intermediate_size: Unsharded expert intermediate dimension.
            params_dtype: Dtype for the gate and expert weights; the torch
                default dtype is used when ``None``.
            tp_size: Tensor-parallel world size; queried from the
                distributed state when ``None`` (or 0).
        """
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # From here on this is the per-rank slice of the intermediate dim.
        self.intermediate_size = intermediate_size // self.tp_size
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )

        # Router: replicated on every rank, never quantized.
        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=self.params_dtype,
            quant_config=None,
        )

        target_device = current_platform.device_type
        # Stacked w1/w3 shards: rows [0, I) hold w1 and rows [I, 2I) hold
        # w3, where I is the per-rank intermediate size (see weight_loader).
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.hidden_size,
                device=target_device,
                dtype=self.params_dtype,
            )
        )
        # Stacked w2 shards.
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.hidden_size,
                self.intermediate_size,
                device=target_device,
                dtype=self.params_dtype,
            )
        )

        # Both stacked tensors are filled expert-by-expert by weight_loader.
        for stacked in (self.ws, self.w2s):
            set_weight_attrs(stacked, {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        """Copy this rank's slice of one expert weight into ``param``.

        Args:
            param: Stacked destination parameter (``ws`` or ``w2s``).
            loaded_weight: Full, unsharded checkpoint tensor.
            weight_name: Checkpoint name; its suffix (``w1.weight``,
                ``w3.weight`` or ``w2.weight``) selects the destination
                region. Any other suffix leaves ``param`` untouched.
            expert_id: Index of the expert along dim 0 of ``param``.
        """
        rank = get_tensor_model_parallel_rank()
        width = self.intermediate_size
        local = slice(rank * width, (rank + 1) * width)
        dest = param.data

        if weight_name.endswith("w1.weight"):
            # w1 shard occupies the first half of the stacked rows.
            dest[expert_id, 0:width, :] = loaded_weight[local, :]
        elif weight_name.endswith("w3.weight"):
            # w3 shard occupies the second half of the stacked rows.
            dest[expert_id, width : 2 * width, :] = loaded_weight[local, :]
        elif weight_name.endswith("w2.weight"):
            # w2 is sliced along its columns (the intermediate dimension).
            dest[expert_id, :, :] = loaded_weight[:, local]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and combine the outputs.

        Args:
            hidden_states: 2-D input; it is unpacked as
                ``(num_tokens, hidden_size)``.

        Returns:
            Tensor viewed back to the input's 2-D shape.
        """
        n_tokens, n_hidden = hidden_states.shape
        flat = hidden_states.view(-1, self.hidden_size)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(flat)
        topk_weights, topk_ids, _ = fused_topk(
            flat, router_logits, self.top_k, renormalize=True
        )
        mixed = fused_experts(
            flat, self.ws, self.w2s, topk_weights, topk_ids, inplace=True
        )

        # Each rank only holds a slice of every expert, so sum the partials.
        if self.tp_size > 1:
            mixed = tensor_model_parallel_all_reduce(mixed)

        return mixed.view(n_tokens, n_hidden)
class MiniCPMMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, activation,
    then down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_act_param: float,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and pick the gated activation.

        Args:
            hidden_size: Model hidden dimension.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name, ``"silu"`` or ``"fatrelu"``.
            hidden_act_param: Threshold passed to ``FatreluAndMul``; unused
                for ``"silu"``.
            quant_config: Optional quantization configuration forwarded to
                the linear layers.
            prefix: Module-name prefix forwarded to the linear layers.

        Raises:
            ValueError: If ``hidden_act`` is neither ``"silu"`` nor
                ``"fatrelu"``.
        """
        super().__init__()

        # Gate and up projections share one merged column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act not in ("silu", "fatrelu"):
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu and fatrelu are supported for now."
            )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            self.act_fn = FatreluAndMul(threshold=hidden_act_param)

    def forward(self, x):
        """Apply gate/up projection, the gated activation and down projection."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
class MiniCPMAttention(nn.Module):
    """Self-attention block with a fused QKV projection and rotary
    embeddings that are applied in float32."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Derive per-rank head counts and build the sub-layers.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            rope_parameters: Rotary-embedding settings forwarded to
                ``get_rope``.
            max_position_embeddings: Maximum position forwarded to
                ``get_rope``.
            cache_config: Optional cache configuration forwarded to
                ``Attention``.
            quant_config: Optional quantization configuration forwarded to
                the projections and to ``Attention``.
            prefix: Module-name prefix for the sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_world = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world == 0
        self.num_heads = self.total_num_heads // tp_world

        self.total_num_kv_heads = num_kv_heads
        kv_is_partitioned = self.total_num_kv_heads >= tp_world
        if kv_is_partitioned:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_world == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor-parallel GPUs.
            assert tp_world % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``.

        Args:
            positions: Position tensor passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            Output of the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        # Rotary embedding is applied in float32; q/k are cast back to their
        # original dtype afterwards.
        compute_dtype = query.dtype
        query, key = self.rotary_emb(positions, query.float(), key.float())
        query = query.to(compute_dtype)
        key = key.to(compute_dtype)

        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
class MiniCPMDecoderLayer(nn.Module):
    """One MiniCPM transformer layer: pre-norm attention followed by a
    pre-norm feed-forward (dense MLP or MoE), each added to the residual
    stream after scaling by ``scale_depth / sqrt(num_hidden_layers)``."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Store the configuration and build both sub-blocks.

        Args:
            config: HuggingFace model configuration.
            cache_config: Optional cache configuration for attention.
            quant_config: Optional quantization configuration.
            prefix: Fully qualified module name of this layer; sub-modules
                derive their own prefixes from it.
        """
        super().__init__()
        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.prefix = prefix
        self._init_attn_block()
        self._init_ffn_block()

    def _init_attn_block(self):
        """Create the input layer norm and the self-attention module."""
        self.input_layernorm = RMSNorm(
            self.config.hidden_size, eps=self.config.rms_norm_eps
        )
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=self.config.num_attention_heads,
            num_kv_heads=self.config.num_key_value_heads,
            rope_parameters=self.config.rope_parameters,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            prefix=f"{self.prefix}.self_attn",
        )

    def _init_ffn_block(self):
        """Create the post-attention layer norm and the feed-forward module.

        A dense ``MiniCPMMLP`` is used when the config declares no experts
        (``num_experts`` missing or 0); otherwise a ``MiniCPMMoE`` is built.
        """
        self.post_attention_layernorm = RMSNorm(
            self.config.hidden_size, eps=self.config.rms_norm_eps
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        if self.num_experts == 0:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.config.intermediate_size,
                hidden_act=self.config.hidden_act,
                hidden_act_param=getattr(self.config, "hidden_act_param", 0.0),
                quant_config=self.quant_config,
                # Pass the qualified name, as the attention block does, so
                # the MLP's linear layers are registered under
                # "<layer>.mlp.*" rather than ".gate_up_proj"/".down_proj".
                prefix=f"{self.prefix}.mlp",
            )
        else:
            self.mlp = MiniCPMMoE(
                num_experts=self.config.num_experts,
                top_k=self.config.num_experts_per_tok,
                hidden_size=self.config.hidden_size,
                intermediate_size=self.config.intermediate_size,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the attention and feed-forward sub-blocks.

        Args:
            positions: Position tensor forwarded to attention.
            hidden_states: Layer input.
            residual: Accepted for interface compatibility; the incoming
                value is ignored and the residual is rebuilt from
                ``hidden_states``.

        Returns:
            ``(hidden_states, None)``; the residual is already folded into
            the returned hidden states.
        """
        # Both residual branches use the same depth-dependent scale.
        branch_scale = self.config.scale_depth / math.sqrt(
            self.config.num_hidden_layers
        )

        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * branch_scale

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * branch_scale

        return hidden_states, None
2026-01-19 10:38:50 +08:00
@support_torch_compile
class MiniCPMModel(nn.Module):
    """MiniCPM decoder stack: scaled token embedding, decoder layers and a
    final RMS norm, with pipeline-parallel hand-off between ranks."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Engine configuration; the HF model config, cache
                config and quantization config are read from it.
            prefix: Module-name prefix for this model.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.vocab_size = hf_config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        self._init_layers(prefix, hf_config, cache_config, quant_config)
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)

        # Layer indices whose inputs are captured as auxiliary hidden
        # states in forward(); empty until set by the owner of this model.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )

    def _init_layers(
        self,
        prefix: str,
        config: PretrainedConfig,
        cache_config: CacheConfig | None,
        quant_config: QuantizationConfig | None,
    ):
        """Create the decoder layers and record ``start_layer``/``end_layer``."""
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiniCPMDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids and multiply by the config's ``scale_emb``."""
        return self.embed_tokens(input_ids) * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used only on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: Hidden state and residual received from
                the previous pipeline rank; read on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` on non-last pipeline ranks; otherwise
            the normalised hidden states, paired with the list of captured
            auxiliary hidden states when that list is non-empty.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        local_layers = islice(self.layers, self.start_layer, self.end_layer)
        for idx, layer in enumerate(local_layers):
            if idx in self.aux_hidden_state_layers:
                # Capture the layer input, with any pending residual added.
                if residual is not None:
                    aux_hidden_states.append(hidden_states + residual)
                else:
                    aux_hidden_states.append(hidden_states)
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states = self.norm(hidden_states)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names are first matched against the fused-projection
        table, then against the per-expert table, and finally loaded by
        their own name. The for/else ladder below relies on ``name`` being
        rewritten in place by earlier attempts.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (possibly rewritten) names visited by the loop,
            excluding rotary-embedding entries that were skipped outright
            and entries skipped in the fallback branch.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if expert_weight in ["w1", "w3"] else "w2s",
                f"experts.{expert_idx}.{expert_weight}.weight",
                expert_idx,
            )
            for expert_idx in range(self.num_experts)
            for expert_weight in ["w1", "w2", "w3"]
        ]
        skipped_rotary_keys = (
            "rotary_emb.inv_freq",
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if any(key in name for key in skipped_rotary_keys):
                continue

            for fused_name, ckpt_name, shard_id in stacked_params_mapping:
                if ckpt_name not in name:
                    continue
                name = name.replace(ckpt_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for stacked_name, ckpt_name, expert_id in expert_params_mapping:
                    if ckpt_name not in name:
                        continue
                    name = name.replace(ckpt_name, stacked_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    param.weight_loader(
                        param, loaded_weight, ckpt_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    load_fn = getattr(param, "weight_loader", default_weight_loader)
                    load_fn(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """MiniCPM causal language model: decoder stack plus LM head, with the
    final hidden states divided by ``hidden_size / dim_model_base``."""

    # Checkpoint projections that are fused into a single module here.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inner model, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration.
            prefix: Module-name prefix for this model.

        Raises:
            NotImplementedError: If EPLB is enabled and the config declares
                a positive ``num_experts``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.prefix = prefix
        self.vllm_config = vllm_config
        self.config = hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = quant_config

        self.model = self._init_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Share the output projection with the input embedding table.
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        # Divisor applied to the final hidden states in forward().
        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        has_experts = getattr(hf_config, "num_experts", 0) > 0
        if vllm_config.parallel_config.enable_eplb and has_experts:
            raise NotImplementedError("EPLB is not supported for MiniCPM yet.")

    def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Construct the inner decoder stack."""
        return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set the layer indices whose inputs the inner model captures."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default auxiliary layer indices: 2, the middle layer
        and the third-from-last layer of ``self.model.layers``."""
        depth = len(self.model.layers)
        return (2, depth // 2, depth - 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the inner model and rescale its final hidden states.

        Returns:
            ``IntermediateTensors`` unchanged when the inner model returns
            them; otherwise the hidden states divided by ``scale_width``,
            paired with the auxiliary hidden states when the inner model
            returned a 2-tuple.
        """
        outputs = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        # Aux hidden states are present.
        if isinstance(outputs, tuple) and len(outputs) == 2:
            final_states, aux_states = outputs
            return final_states / self.scale_width, aux_states

        # Pipeline hand-off: pass through untouched.
        if isinstance(outputs, IntermediateTensors):
            return outputs

        # Only hidden states.
        return outputs / self.scale_width

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader``, skipping ``lm_head.``
        entries when the word embeddings are tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)