forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3 and llama4
This commit is contained in:
560
vllm-v0.6.2/vllm/model_executor/models/llama4.py
Normal file
560
vllm-v0.6.2/vllm/model_executor/models/llama4.py
Normal file
@@ -0,0 +1,560 @@
|
||||
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Llama4 model compatible with HuggingFace weights."""
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .llama import LlamaMLP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _extract_layer_index(prefix: str) -> int:
|
||||
"""Extract layer index from prefix string like 'model.layers.0.self_attn'."""
|
||||
match = re.search(r'layers\.(\d+)', prefix)
|
||||
if match is None:
|
||||
raise ValueError(f"Cannot extract layer index from prefix: {prefix}")
|
||||
return int(match.group(1))
|
||||
|
||||
|
||||
class Llama4MoE(nn.Module):
|
||||
"""Llama4 Mixture of Experts with shared expert."""
|
||||
|
||||
@staticmethod
|
||||
def custom_routing_function(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_scores, router_indices = torch.topk(
|
||||
gating_output, topk, dim=-1)
|
||||
router_scores = torch.sigmoid(router_scores.float())
|
||||
return (router_scores, router_indices.to(torch.int32))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.top_k = getattr(config, "num_experts_per_tok", 1)
|
||||
self.num_local_experts = getattr(config, "num_local_experts", 8)
|
||||
self.hidden_size = getattr(config, "hidden_size", 4096)
|
||||
intermediate_size_moe = getattr(config, "intermediate_size", 8192)
|
||||
|
||||
self.router = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.num_local_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.router",
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=self.num_local_experts,
|
||||
top_k=self.top_k,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
reduce_results=False,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.shared_expert = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
# routed experts
|
||||
routed_out = self.experts(hidden_states, router_logits)
|
||||
# shared expert
|
||||
shared_out = self.shared_expert(hidden_states)
|
||||
# combine and all-reduce
|
||||
experts_out = routed_out + shared_out
|
||||
if self.tp_size > 1:
|
||||
experts_out = tensor_model_parallel_all_reduce(experts_out)
|
||||
return experts_out.view(orig_shape)
|
||||
|
||||
|
||||
class Llama4Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = _extract_layer_index(prefix)
|
||||
self.hidden_size = hidden_size
|
||||
self.no_rope_layers = getattr(config, "no_rope_layers", None)
|
||||
self.nope = (self.no_rope_layers is not None
|
||||
and self.no_rope_layers[self.layer_idx] == 0)
|
||||
self.use_qk_norm = getattr(config, "use_qk_norm", False) and not self.nope
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = getattr(config, "head_dim",
|
||||
self.hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# Temperature tuning for NoPE layers
|
||||
self.attn_temperature_tuning = (
|
||||
self.nope and getattr(config, "attn_temperature_tuning", False))
|
||||
self.floor_scale = getattr(config, "floor_scale", 8192.0)
|
||||
self.attn_scale = getattr(config, "attn_scale", 0.1)
|
||||
|
||||
# QK norm
|
||||
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
# v0.6.2 RMSNorm doesn't support has_weight=False,
|
||||
# so we set weight to ones and make it non-trainable
|
||||
self.qk_norm.weight.data.fill_(1.0)
|
||||
self.qk_norm.weight.requires_grad = False
|
||||
else:
|
||||
self.qk_norm = None
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
# RoPE (None for NoPE layers)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if not self.nope:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=True,
|
||||
)
|
||||
else:
|
||||
self.rotary_emb = None
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||
floor = torch.floor((positions + 1.0) / self.floor_scale)
|
||||
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
|
||||
return attn_scale.unsqueeze(-1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
if self.rotary_emb is not None:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
if self.qk_norm is not None:
|
||||
q = q.reshape(-1, self.head_dim)
|
||||
q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
|
||||
k = k.reshape(-1, self.head_dim)
|
||||
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
|
||||
|
||||
if self.attn_temperature_tuning and self.nope:
|
||||
attn_scale = self._get_attn_scale(positions)
|
||||
q = (q * attn_scale).to(q.dtype)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Llama4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = _extract_layer_index(prefix)
|
||||
self.hidden_size = getattr(config, "hidden_size", 4096)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
|
||||
self.self_attn = Llama4Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=getattr(config, "num_attention_heads", 32),
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
getattr(config, "num_attention_heads", 32)),
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# Interleaved MoE/dense layers
|
||||
interleave_moe_layer_step = getattr(config,
|
||||
"interleave_moe_layer_step", 0)
|
||||
is_moe_layer = (interleave_moe_layer_step > 0
|
||||
and (self.layer_idx + 1)
|
||||
% interleave_moe_layer_step == 0)
|
||||
|
||||
if is_moe_layer:
|
||||
self.feed_forward = Llama4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
intermediate_size_mlp = getattr(config, "intermediate_size_mlp",
|
||||
getattr(config,
|
||||
"intermediate_size", 8192))
|
||||
self.feed_forward = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=intermediate_size_mlp,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
|
||||
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
|
||||
self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(self.hidden_size,
|
||||
eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class Llama4Model(nn.Module):
|
||||
"""Llama4 model - independent implementation to avoid pad_token_id issue."""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
# Defensive access - Llama4Config may not have pad_token_id
|
||||
self.padding_idx = getattr(config, "pad_token_id", None)
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank or (
|
||||
getattr(config, "tie_word_embeddings", False)
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Llama4DecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Llama4ForCausalLM(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = Llama4Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE if not lora_config
|
||||
else lora_config.lora_vocab_padding_size),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if getattr(config, "tie_word_embeddings", False):
|
||||
self.lm_head = self.lm_head.tie_weights(
|
||||
self.model.embed_tokens)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
self.sampler = get_sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def permute_qk_weight_for_rotary(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
"""Permute Q/K weights for rotary embedding compatibility."""
|
||||
def permute(w: torch.Tensor, n_heads: int):
|
||||
attn_in = getattr(self.config, "head_dim", 128) * n_heads
|
||||
attn_out = getattr(self.config, "hidden_size", 4096)
|
||||
return (w.contiguous()
|
||||
.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
|
||||
.transpose(1, 2).reshape(attn_in, attn_out))
|
||||
|
||||
modules = name.split(".")
|
||||
is_weight = modules[-1] == "weight"
|
||||
|
||||
if is_weight:
|
||||
if "k_proj" in modules:
|
||||
loaded_weight = permute(
|
||||
loaded_weight,
|
||||
getattr(self.config, "num_key_value_heads", 8))
|
||||
elif "q_proj" in modules:
|
||||
loaded_weight = permute(
|
||||
loaded_weight,
|
||||
getattr(self.config, "num_attention_heads", 32))
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(
|
||||
["lm_head."]
|
||||
if getattr(self.config, "tie_word_embeddings", False)
|
||||
else None),
|
||||
)
|
||||
weights = [
|
||||
self.permute_qk_weight_for_rotary(name, loaded_weight)
|
||||
for name, loaded_weight in weights
|
||||
]
|
||||
loader.load_weights(weights)
|
||||
@@ -65,6 +65,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
|
||||
Reference in New Issue
Block a user