# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Llama4 model compatible with HuggingFace weights.""" import re from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .llama import LlamaMLP 
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

logger = init_logger(__name__)


def _extract_layer_index(prefix: str) -> int:
    """Return the decoder-layer index embedded in a module prefix.

    Args:
        prefix: Dotted module path, e.g. ``'model.layers.0.self_attn'``.

    Returns:
        The integer that follows the first ``layers.`` segment.

    Raises:
        ValueError: If ``prefix`` has no ``layers.<digits>`` segment.
    """
    match = re.search(r'layers\.(\d+)', prefix)
    if match is None:
        raise ValueError(f"Cannot extract layer index from prefix: {prefix}")
    return int(match.group(1))


class Llama4MoE(nn.Module):
    """Llama4 Mixture of Experts with shared expert.

    Each token is processed by the ``top_k`` routed experts chosen by a
    replicated router and by one always-active shared expert. Both partial
    results are kept rank-local and summed, then all-reduced once across
    tensor-parallel ranks.
    """

    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the ``topk`` highest router logits and gate them by sigmoid.

        ``hidden_states`` and ``renormalize`` belong to FusedMoE's routing
        callback signature and are not used here.

        Returns:
            ``(weights, expert_ids)``: float32 sigmoid of the selected logits
            and the matching expert indices as int32.
        """
        router_scores, router_indices = torch.topk(
            gating_output, topk, dim=-1)
        router_scores = torch.sigmoid(router_scores.float())
        return (router_scores, router_indices.to(torch.int32))

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = getattr(config, "num_experts_per_tok", 1)
        self.num_local_experts = getattr(config, "num_local_experts", 8)
        self.hidden_size = getattr(config, "hidden_size", 4096)
        intermediate_size_moe = getattr(config, "intermediate_size", 8192)

        # Replicated (not sharded): every rank needs the full routing logits.
        self.router = ReplicatedLinear(
            self.hidden_size,
            self.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.router",
        )

        # NOTE(review): FusedMoE multiplies the routing weights returned by
        # ``custom_routing_function`` into the expert *outputs*. The HF
        # reference Llama4 MoE scales the expert *inputs* by the router score
        # instead; with a gated (nonlinear) MLP these are not equivalent.
        # Confirm against reference outputs (newer vLLM releases expose
        # ``apply_router_weight_on_input`` for this).
        self.experts = FusedMoE(
            num_experts=self.num_local_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=intermediate_size_moe,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            custom_routing_function=Llama4MoE.custom_routing_function,
            prefix=f"{prefix}.experts",
        )

        self.shared_expert = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=intermediate_size_moe,
            hidden_act="silu",
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.shared_expert",
        )
        # ``forward`` sums the routed and shared outputs and all-reduces the
        # sum once. LlamaMLP's down projection is a RowParallelLinear that
        # all-reduces on its own by default; left on, the shared output would
        # be reduced twice (i.e. scaled by tp_size). Keep it rank-local.
        self.shared_expert.down_proj.reduce_results = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply routed + shared experts; output has the input's shape."""
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)

        # Shared expert runs first: the fused MoE kernel may write its result
        # into ``hidden_states`` in place, which would corrupt the shared
        # expert's input if it ran afterwards.
        shared_out = self.shared_expert(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_logits)

        # Both terms are rank-local partial sums; reduce exactly once.
        experts_out = routed_out + shared_out
        if self.tp_size > 1:
            experts_out = tensor_model_parallel_all_reduce(experts_out)
        return experts_out.view(orig_shape)


class Llama4Attention(nn.Module):
    """Llama4 self-attention with optional RoPE, QK norm and NoPE scaling.

    Layers flagged 0 in ``config.no_rope_layers`` are "NoPE" layers. They
    skip rotary embedding and QK norm, and may instead scale queries by a
    position-dependent temperature (``attn_temperature_tuning``).
    """

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = _extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.no_rope_layers = getattr(config, "no_rope_layers", None)
        self.nope = (self.no_rope_layers is not None
                     and self.no_rope_layers[self.layer_idx] == 0)
        self.use_qk_norm = (getattr(config, "use_qk_norm", False)
                            and not self.nope)

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: KV heads are replicated.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Temperature tuning only ever applies to NoPE layers.
        self.attn_temperature_tuning = bool(
            self.nope and getattr(config, "attn_temperature_tuning", False))
        self.floor_scale = getattr(config, "floor_scale", 8192.0)
        self.attn_scale = getattr(config, "attn_scale", 0.1)

        # QK norm
        rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
        if self.use_qk_norm:
            # ``forward`` normalizes float32 copies of q/k, so the norm weight
            # is kept in float32 as well. The CUDA RMSNorm path reads the
            # weight with the input's dtype and cannot mix dtypes.
            self.qk_norm = RMSNorm(self.head_dim,
                                   eps=rms_norm_eps).to(torch.float32)
            # v0.6.2 RMSNorm doesn't support has_weight=False,
            # so we set weight to ones and make it non-trainable
            self.qk_norm.weight.data.fill_(1.0)
            self.qk_norm.weight.requires_grad = False
        else:
            self.qk_norm = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # RoPE (None for NoPE layers)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if not self.nope:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
                rope_scaling=rope_scaling,
                is_neox_style=True,
            )
        else:
            self.rotary_emb = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Per-position query scale ``1 + attn_scale*log(1 + floor((p+1)/floor_scale))``.

        The result has a trailing singleton dim so it broadcasts over q.
        """
        floor = torch.floor((positions + 1.0) / self.floor_scale)
        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
        return attn_scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE / QK norm / temperature, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        if self.qk_norm is not None:
            # Normalize each head independently over head_dim, in float32,
            # then restore the packed (tokens, heads*head_dim) layout.
            q = q.reshape(-1, self.head_dim)
            q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
            k = k.reshape(-1, self.head_dim)
            k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)

        # Applied after RoPE/QK norm and before attention (NoPE layers only).
        if self.attn_temperature_tuning:
            attn_scale = self._get_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Llama4DecoderLayer(nn.Module):
    """Pre-norm decoder layer: attention, then a dense MLP or MoE block.

    Every ``interleave_moe_layer_step``-th layer (1-based) uses
    :class:`Llama4MoE`; the others use a dense :class:`LlamaMLP`.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = _extract_layer_index(prefix)
        self.hidden_size = getattr(config, "hidden_size", 4096)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        self.self_attn = Llama4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=getattr(config, "num_attention_heads", 32),
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 getattr(config, "num_attention_heads", 32)),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        # Interleaved MoE/dense layers
        interleave_moe_layer_step = getattr(config,
                                            "interleave_moe_layer_step", 0)
        is_moe_layer = (interleave_moe_layer_step > 0 and
                        (self.layer_idx + 1) % interleave_moe_layer_step == 0)
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            intermediate_size_mlp = getattr(
                config, "intermediate_size_mlp",
                getattr(config, "intermediate_size", 8192))
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                bias=False,
                prefix=f"{prefix}.feed_forward",
            )

        rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
        self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(self.hidden_size,
                                                eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer; returns ``(hidden_states, residual)``.

        The residual add is deferred: it is fused into the next RMSNorm call
        (here or in the following layer / final norm).
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class Llama4Model(nn.Module):
    """Llama4 model - independent implementation to avoid pad_token_id issue."""

    # (vLLM param name fragment, checkpoint name fragment, shard id) for
    # projections stored separately in checkpoints but packed in vLLM.
    _STACKED_PARAMS_MAPPING = (
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Defensive access - Llama4Config may not have pad_token_id
        self.padding_idx = getattr(config, "pad_token_id", None)
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # The embedding lives on the first PP rank, and also on the last one
        # when it is tied to the LM head.
        if get_pp_group().is_first_rank or (
                getattr(config, "tie_word_embeddings", False)
                and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Llama4DecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        Returns final normed hidden states on the last PP rank, otherwise
        the ``IntermediateTensors`` to hand to the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> set:
        """Load checkpoint tensors whose names are relative to this module.

        Checkpoint names that do not mirror the vLLM module tree are mapped:
        separate q/k/v and gate/up projections are packed, HF-style fused
        expert tensors and per-expert tensors go through FusedMoE's weight
        loader. Everything else is loaded by name. Tensors belonging to
        layers that live on another pipeline-parallel rank are skipped.

        Returns:
            Names (relative to this module) of the parameters loaded.
        """
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=getattr(self.config, "num_local_experts", 8),
        )
        params_dict: Dict[str, nn.Parameter] = dict(self.named_parameters())
        loaded_params = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self._load_fused_expert_weight(name, loaded_weight,
                                              params_dict, loaded_params):
                continue
            if self._load_per_expert_weight(name, loaded_weight, params_dict,
                                            expert_params_mapping,
                                            loaded_params):
                continue
            if self._load_stacked_weight(name, loaded_weight, params_dict,
                                         loaded_params):
                continue
            self._load_plain_weight(name, loaded_weight, params_dict,
                                    loaded_params)
        return loaded_params

    def _load_fused_expert_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, nn.Parameter],
        loaded_params: set,
    ) -> bool:
        """Load an ``experts.gate_up_proj``/``experts.down_proj`` tensor.

        Returns False if ``name`` is not such a tensor. Assumes the HF
        Llama4 layout -- gate_up_proj (num_experts, hidden, 2*intermediate),
        down_proj (num_experts, intermediate, hidden), i.e. input-major --
        so each expert slice is transposed to the (out, in) layout FusedMoE
        loads. NOTE(review): confirm for non-HF checkpoint formats.
        """
        module_prefix, _, tensor_name = name.rpartition(".")
        if (not module_prefix.endswith(".experts")
                or tensor_name not in ("gate_up_proj", "down_proj")):
            return False
        if is_pp_missing_parameter(name, self):
            return True

        if tensor_name == "gate_up_proj":
            param_name = f"{module_prefix}.w13_weight"
            gate, up = loaded_weight.chunk(2, dim=-1)
            shards = (("w1", gate), ("w3", up))
        else:
            param_name = f"{module_prefix}.w2_weight"
            shards = (("w2", loaded_weight), )

        param = params_dict[param_name]
        for shard_id, stacked in shards:
            for expert_id, expert_weight in enumerate(stacked):
                param.weight_loader(param,
                                    expert_weight.transpose(0, 1),
                                    param_name,
                                    shard_id=shard_id,
                                    expert_id=expert_id)
        loaded_params.add(param_name)
        return True

    def _load_per_expert_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, nn.Parameter],
        expert_params_mapping: List[Tuple[str, str, int, str]],
        loaded_params: set,
    ) -> bool:
        """Load an ``experts.<i>.<proj>.*`` tensor; False if not one."""
        for param_name, weight_name, expert_id, shard_id in \
                expert_params_mapping:
            if weight_name not in name:
                continue
            target = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(target, self):
                return True
            param = params_dict[target]
            param.weight_loader(param,
                                loaded_weight,
                                target,
                                shard_id=shard_id,
                                expert_id=expert_id)
            loaded_params.add(target)
            return True
        return False

    def _load_stacked_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, nn.Parameter],
        loaded_params: set,
    ) -> bool:
        """Load one shard of a packed q/k/v or gate/up projection.

        Returns False if ``name`` is not a packed-projection shard, or if
        the packed target does not exist (e.g. an unpacked quantization
        scale), leaving it to the generic path.
        """
        for param_name, shard_name, shard_id in self._STACKED_PARAMS_MAPPING:
            if shard_name not in name:
                continue
            target = name.replace(shard_name, param_name)
            if is_pp_missing_parameter(target, self):
                return True
            if target not in params_dict:
                return False
            param = params_dict[target]
            param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(target)
            return True
        return False

    def _load_plain_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, nn.Parameter],
        loaded_params: set,
    ) -> None:
        """Load a tensor that maps 1:1 onto a parameter (after kv-scale remap)."""
        target = maybe_remap_kv_scale_name(name, params_dict)
        if target is None or is_pp_missing_parameter(target, self):
            return
        param = params_dict[target]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(target)


class Llama4ForCausalLM(nn.Module, SupportsPP):
    """Llama4 text decoder with LM head, loadable from text-only or
    multimodal (Llama4ForConditionalGeneration) checkpoints."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # Llama4ForConditionalGeneration uses top-level Llama4Config
        # which has text_config sub-config. Extract it for text model.
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            orig_archs = getattr(config, "architectures", None)
            # Swap the text config in so sub-modules built from vllm_config
            # (e.g. Llama4Model) see text-model fields.
            vllm_config.model_config.hf_config = text_config
            if orig_archs and not getattr(text_config, "architectures",
                                          None):
                text_config.architectures = orig_archs
            config = text_config

        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = Llama4Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    if not lora_config else lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if getattr(config, "tie_word_embeddings", False):
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = get_sampler()
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Permute Q/K weights for rotary embedding compatibility.

        Within each head, rows interleaved as pairs are regrouped into two
        halves (the layout neox-style RoPE rotates). Only ``q_proj``/
        ``k_proj`` ``.weight`` tensors are touched; the name is unchanged.
        Config fallbacks match the ones used by the attention layers.
        """
        hidden_size = getattr(self.config, "hidden_size", 4096)
        num_heads = getattr(self.config, "num_attention_heads", 32)
        num_kv_heads = getattr(self.config, "num_key_value_heads", num_heads)
        head_dim = getattr(self.config, "head_dim", hidden_size // num_heads)

        def permute(w: torch.Tensor, n_heads: int):
            attn_in = head_dim * n_heads
            attn_out = hidden_size
            return (w.contiguous()
                    .view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                    .transpose(1, 2).reshape(attn_in, attn_out))

        modules = name.split(".")
        if modules[-1] == "weight":
            if "k_proj" in modules:
                loaded_weight = permute(loaded_weight, num_kv_heads)
            elif "q_proj" in modules:
                loaded_weight = permute(loaded_weight, num_heads)
        return name, loaded_weight

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ):
        """Load a Llama4 text checkpoint, possibly embedded in a multimodal one.

        ``language_model.`` prefixes (Llama4ForConditionalGeneration) are
        stripped, vision-tower/projector tensors are dropped, Q/K projections
        are permuted for rotary, and ``lm_head`` is skipped when embeddings
        are tied.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` reports as loaded.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if getattr(self.config, "tie_word_embeddings",
                                      False) else None),
        )

        def _process_weights(weights):
            for name, loaded_weight in weights:
                # Strip language_model. prefix for Llama4ForConditionalGeneration
                if name.startswith("language_model."):
                    name = name[len("language_model."):]
                # Skip vision encoder weights
                elif name.startswith(("multi_modal_projector.",
                                      "vision_encoder.", "vision_model.")):
                    continue
                yield self.permute_qk_weight_for_rotary(name, loaded_weight)

        return loader.load_weights(_process_weights(weights))