diff --git a/vllm-v0.6.2/vllm/model_executor/models/llama4.py b/vllm-v0.6.2/vllm/model_executor/models/llama4.py new file mode 100644 index 0000000..4ec2076 --- /dev/null +++ b/vllm-v0.6.2/vllm/model_executor/models/llama4.py @@ -0,0 +1,560 @@ +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Llama4 model compatible with HuggingFace weights.""" +import re +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .llama import LlamaMLP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +def _extract_layer_index(prefix: str) -> int: + """Extract layer index from prefix string like 'model.layers.0.self_attn'.""" + match = re.search(r'layers\.(\d+)', prefix) + if match is None: + raise ValueError(f"Cannot extract layer index from prefix: {prefix}") + return int(match.group(1)) + + +class Llama4MoE(nn.Module): + """Llama4 Mixture of Experts with shared expert.""" + + @staticmethod + def custom_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + router_scores, router_indices = torch.topk( + gating_output, topk, dim=-1) + router_scores = torch.sigmoid(router_scores.float()) + return (router_scores, router_indices.to(torch.int32)) + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.top_k = getattr(config, "num_experts_per_tok", 1) + self.num_local_experts = getattr(config, 
"num_local_experts", 8) + self.hidden_size = getattr(config, "hidden_size", 4096) + intermediate_size_moe = getattr(config, "intermediate_size", 8192) + + self.router = ReplicatedLinear( + self.hidden_size, + self.num_local_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.router", + ) + + self.experts = FusedMoE( + num_experts=self.num_local_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=intermediate_size_moe, + reduce_results=False, + renormalize=False, + quant_config=quant_config, + custom_routing_function=Llama4MoE.custom_routing_function, + prefix=f"{prefix}.experts", + ) + + self.shared_expert = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.router(hidden_states) + # routed experts + routed_out = self.experts(hidden_states, router_logits) + # shared expert + shared_out = self.shared_expert(hidden_states) + # combine and all-reduce + experts_out = routed_out + shared_out + if self.tp_size > 1: + experts_out = tensor_model_parallel_all_reduce(experts_out) + return experts_out.view(orig_shape) + + +class Llama4Attention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = _extract_layer_index(prefix) + self.hidden_size = hidden_size + self.no_rope_layers = getattr(config, "no_rope_layers", None) + self.nope = (self.no_rope_layers is not None + and self.no_rope_layers[self.layer_idx] == 0) + self.use_qk_norm = getattr(config, "use_qk_norm", False) and not self.nope + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + # Temperature tuning for NoPE layers + self.attn_temperature_tuning = ( + self.nope and getattr(config, "attn_temperature_tuning", False)) + self.floor_scale = getattr(config, "floor_scale", 8192.0) + self.attn_scale = getattr(config, "attn_scale", 0.1) + + # QK norm + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + if self.use_qk_norm: + self.qk_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + # v0.6.2 RMSNorm doesn't support has_weight=False, + # so we set weight to ones and make it non-trainable + self.qk_norm.weight.data.fill_(1.0) + self.qk_norm.weight.requires_grad = False + else: + self.qk_norm = None + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + 
total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # RoPE (None for NoPE layers) + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if not self.nope: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + else: + self.rotary_emb = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: + floor = torch.floor((positions + 1.0) / self.floor_scale) + attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 + return attn_scale.unsqueeze(-1) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + if self.rotary_emb is not None: + q, k = self.rotary_emb(positions, q, k) + + if self.qk_norm is not None: + q = q.reshape(-1, self.head_dim) + q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype) + k = k.reshape(-1, self.head_dim) + k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype) + + if self.attn_temperature_tuning and self.nope: + attn_scale = self._get_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Llama4DecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = _extract_layer_index(prefix) + self.hidden_size = getattr(config, "hidden_size", 4096) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = Llama4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=getattr(config, "num_attention_heads", 32), + num_kv_heads=getattr(config, "num_key_value_heads", + getattr(config, "num_attention_heads", 32)), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + # Interleaved MoE/dense layers + interleave_moe_layer_step = getattr(config, + "interleave_moe_layer_step", 0) + is_moe_layer = (interleave_moe_layer_step > 0 + and (self.layer_idx + 1) + % interleave_moe_layer_step == 0) + + if is_moe_layer: + self.feed_forward = Llama4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + intermediate_size_mlp = getattr(config, "intermediate_size_mlp", + getattr(config, + "intermediate_size", 8192)) + self.feed_forward = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size_mlp, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.feed_forward", + ) + + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) 
+ self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(self.hidden_size, + eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class Llama4Model(nn.Module): + """Llama4 model - independent implementation to avoid pad_token_id issue.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + # Defensive access - Llama4Config may not have pad_token_id + self.padding_idx = getattr(config, "pad_token_id", None) + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or ( + getattr(config, "tie_word_embeddings", False) + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Llama4DecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + 
return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Llama4ForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = Llama4Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE if not lora_config + else lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if getattr(config, "tie_word_embeddings", False): + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def permute_qk_weight_for_rotary( + self, + name: str, + loaded_weight: torch.Tensor, + ) -> Tuple[str, torch.Tensor]: + """Permute Q/K weights for rotary embedding compatibility.""" + def permute(w: torch.Tensor, n_heads: int): + attn_in = getattr(self.config, "head_dim", 128) * n_heads + attn_out = getattr(self.config, "hidden_size", 4096) + return (w.contiguous() + .view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2).reshape(attn_in, attn_out)) + + modules = name.split(".") + is_weight = modules[-1] == "weight" + + if is_weight: + if "k_proj" in modules: + loaded_weight = permute( + loaded_weight, + getattr(self.config, "num_key_value_heads", 8)) + elif "q_proj" in modules: + loaded_weight = permute( + loaded_weight, + getattr(self.config, "num_attention_heads", 32)) + + return name, loaded_weight + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], + ): + loader 
= AutoWeightsLoader(
+            self,
+            skip_prefixes=(
+                ["lm_head."]
+                if getattr(self.config, "tie_word_embeddings", False)
+                else None),
+        )
+        weights = [
+            self.permute_qk_weight_for_rotary(name, loaded_weight)
+            for name, loaded_weight in weights
+        ]
+        loader.load_weights(weights)
diff --git a/vllm-v0.6.2/vllm/model_executor/models/registry.py b/vllm-v0.6.2/vllm/model_executor/models/registry.py
index 75fc5d2..b805401 100644
--- a/vllm-v0.6.2/vllm/model_executor/models/registry.py
+++ b/vllm-v0.6.2/vllm/model_executor/models/registry.py
@@ -65,6 +65,7 @@ _TEXT_GENERATION_MODELS = {
     "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
+    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
index a75fb3d..82932ff 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
@@ -582,15 +582,24 @@ def unified_flash_attention_v2(
     else:
         # unpaged (linear cache) path
         if use_mla:
-            # MLA cache 是 2D (total_slots, head_dim),
-            # 不能用 reshape_paged_cache(期望 4D),直接索引写入
+            # MLA: mirror the handling used on the paged path
+            # key_cache: (num_blocks, 1, block_size, 576)
+            value_to_cache = None
             if attn_metadata.prefill_metadata:
                 # MLA prefill cache 已在 forward_prefill 中写入,跳过
                 pass
             else:
-                # key: (num_tokens, 1, head_dim) → squeeze → (num_tokens, head_dim)
-                # key_cache: (total_slots, head_dim)
-                key_cache[updated_slot_mapping.flatten()] = key.squeeze(1)
+                if kv_cache_dtype == 'int8':
+                    mlu_ops.quant_to_paged_cache(
+                        key, value_to_cache,
+                        key_cache, value_cache,
+                        key_cache_scale, value_cache_scale,
+                        updated_slot_mapping.flatten())
+                else:
+                    mlu_ops.reshape_paged_cache(
+                        key, value_to_cache,
+                        key_cache, value_cache,
+                        updated_slot_mapping.flatten())
         else:
             # FIXME: After TMO-1496 is completed, remove this code.
             if key.stride() != value.stride():
diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
index 8377054..5f1dc07 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
@@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
 
 def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
     """Returns the number of KV heads per GPU."""
-    if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
+    if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
         # feature flag MLA
         return 1
     total_num_kv_heads = self.get_total_num_kv_heads()
@@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_head_size(self) -> int:
     # TODO remove hard code
     if hasattr(self.hf_text_config, "model_type"
-               ) and self.hf_text_config.model_type == 'deepseek_v2':
+               ) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
         '''
         =============================
         Modify by vllm_mlu
         =============================
@@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model
 def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
     result = hasattr(
         self.hf_text_config,
-        "model_type") and self.hf_text_config.model_type == 'deepseek_v2'
+        "model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3')
     return result
 
 MluHijackObject.apply_hijack(ModelConfig,
diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/__init__.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/__init__.py
index e38d8ee..979a66f 100755
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/__init__.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/__init__.py
@@ -39,3 +39,9 @@ try:
 except ImportError as e:
     import logging
     logging.warning(f"Failed to import mllama hijack: {e}")
+
+try:
+    import vllm_mlu.model_executor.models.llama4
+except ImportError as e:
+    import logging
+    logging.warning(f"Failed to import llama4 hijack: {e}")
diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py
new file mode 100644
index 0000000..eb6bd16
--- /dev/null
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py
@@ -0,0 +1,485 @@
+import torch
+import re
+
+from typing import List, Optional, Tuple, Union, Iterable
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm_mlu.model_executor.layers.feed_forward import FeedForward
+from vllm_mlu.mlu_hijack_utils import MluHijackObject
+from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.llama4 import (
+    Llama4Attention, Llama4DecoderLayer, Llama4ForCausalLM,
+    Llama4Model, Llama4MoE)
+from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
+from vllm.model_executor.models.utils import is_pp_missing_parameter
+from vllm.sequence import IntermediateTensors
+
+from vllm_mlu.model_executor.models.layer_utils import (
+    decoder_layer_forward_base, decoder_model_forward_base_pp,
+    is_per_tensor_smoothquant, is_per_token_smoothquant,
+
quant_fusion_with_rmsnorm) + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# ============================================================ +# Llama4MoE MLU replacement: SparseMoeMlp + shared expert +# ============================================================ + +class Llama4MoEMlu(SparseMoeMlp): + """MLU replacement for Llama4MoE using SparseMoeMlp + shared expert.""" + + def __init__(self, config, quant_config=None, prefix=""): + num_local_experts = getattr(config, "num_local_experts", 8) + top_k = getattr(config, "num_experts_per_tok", 1) + hidden_size = getattr(config, "hidden_size", 4096) + intermediate_size = getattr(config, "intermediate_size", 8192) + + super().__init__( + num_experts=num_local_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + up_proj_name="gate_up_proj", + is_gated=True, + down_proj_name="down_proj", + has_bias=False, + skip_bias_add=False, + renormalize=False, + hidden_act="silu", + params_dtype=None, + quant_config=quant_config, + is_use_fused_moe=True, + ) + + # Llama4 uses sigmoid routing, not softmax + # Override topk_softmax to use sigmoid + self._use_sigmoid_routing = True + + # Shared expert (independent from routed experts) + self.shared_expert = FeedForward( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act="silu", + up_proj_name="gate_up_proj", + is_gated=True, + down_proj_name="down_proj", + bias=False, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_expert", + ) + + def topk_softmax(self, expert_logits): + """Override: Llama4 uses sigmoid routing instead of softmax.""" + topk_values, topk_indices = torch.topk( + expert_logits, self.top_k, dim=-1) + topk_values = torch.sigmoid(topk_values.float()) + return topk_values, topk_indices + + def forward(self, hidden_states, residual=None): + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + + # Shared expert output + shared_out = self.shared_expert(hidden_states) + + # Router logits + router_logits, _ = self.gate(hidden_states) + + # Routed experts + routed_out = self.forward_experts(hidden_states, router_logits, None) + + # Combine + final_out = routed_out + shared_out + if self.tp_size > 1: + final_out = tensor_model_parallel_all_reduce(final_out) + + return final_out.view(orig_shape) + + +# ============================================================ +# Llama4Attention hijack +# ============================================================ + +vllm__llama4__Llama4Attention__init__org = Llama4Attention.__init__ + + +def vllm__llama4__Llama4Attention____init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", +) -> None: + vllm__llama4__Llama4Attention__init__org( + self, config, hidden_size, num_heads, num_kv_heads, + max_position_embeddings, quant_config, bias, cache_config, prefix) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: save rope_scaling for MLU RoPE dispatch + ''' + self.rope_scaling = getattr(config, "rope_scaling", None) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +def vllm__llama4__Llama4Attention__forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, 
+ residual: Optional[torch.Tensor] = None, + smooth_quant_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: MLU RoPE: merge q/k, apply rotary, split back (教训 #3) + For NoPE layers (self.rotary_emb is None), skip RoPE entirely. + ''' + if self.rotary_emb is not None: + if (self.rope_scaling is not None + and self.rope_scaling.get("rope_type") == "longrope"): + q, k = self.rotary_emb(positions, q, k) + else: + qk, _ = qkv.split( + [self.q_size + self.kv_size, self.kv_size], dim=-1) + self.rotary_emb( + positions, + qk.view(-1, self.num_heads + self.num_kv_heads, + self.head_dim)) + q, k = qk.split([self.q_size, self.kv_size], dim=-1) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # QK norm (教训 #2: use contiguous + reshape) + if self.qk_norm is not None: + q = q.contiguous().reshape(-1, self.head_dim) + q = (self.qk_norm(q.float()) + .contiguous().reshape(-1, self.q_size).to(q.dtype)) + k = k.contiguous().reshape(-1, self.head_dim) + k = (self.qk_norm(k.float()) + .contiguous().reshape(-1, self.kv_size).to(k.dtype)) + + # Temperature tuning for NoPE layers + if self.attn_temperature_tuning and self.nope: + attn_scale = self._get_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add residual in o_proj + ''' + output, _ = self.o_proj(attn_output, residual) + ''' + ================== + End of MLU Hijack + ================== + ''' + return output + + +# ============================================================ +# Llama4DecoderLayer hijack +# ============================================================ + +def vllm__llama4__Llama4DecoderLayer____init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", +) -> None: + super(Llama4DecoderLayer, self).__init__() + from vllm.model_executor.models.llama4 import ( + _extract_layer_index, Llama4Attention) + + self.layer_idx = _extract_layer_index(prefix) + self.hidden_size = getattr(config, "hidden_size", 4096) + max_position_embeddings = getattr( + config, "max_position_embeddings", 8192) + + self.self_attn = Llama4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=getattr(config, "num_attention_heads", 32), + num_kv_heads=getattr(config, "num_key_value_heads", + getattr(config, "num_attention_heads", 32)), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + interleave_moe_layer_step = getattr( + config, "interleave_moe_layer_step", 0) + is_moe_layer = (interleave_moe_layer_step > 0 + and (self.layer_idx + 1) + % interleave_moe_layer_step == 0) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Replace MoE with Llama4MoEMlu (SparseMoeMlp + shared expert), + Replace dense MLP with FeedForward. 
+ ''' + if is_moe_layer: + self.feed_forward = Llama4MoEMlu( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + intermediate_size_mlp = getattr( + config, "intermediate_size_mlp", + getattr(config, "intermediate_size", 8192)) + self.feed_forward = FeedForward( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size_mlp, + hidden_act="silu", + up_proj_name="gate_up_proj", + is_gated=True, + down_proj_name="down_proj", + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + self.hidden_size, eps=rms_norm_eps) + + self.is_per_tesnor_sq_perf_cases = is_per_tensor_smoothquant( + quant_config) + self.is_per_token_sq_perf_cases = is_per_token_smoothquant( + quant_config) + if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases: + self.self_attn.qkv_proj.quant_method.skip_quant_input = True + self.quant_fusion_attn_layernorm = None + + +def vllm__llama4__Llama4DecoderLayer__forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use decoder_layer_forward_base with residual-in-matmul + and optional quant fusion. + ''' + attn_layernorm = self.input_layernorm + if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases: + if self.quant_fusion_attn_layernorm is None: + if self.is_per_token_sq_perf_cases: + attn_quant_scale = self.self_attn.qkv_proj.smooth + else: + attn_quant_scale = self.self_attn.qkv_proj.scale_to_int + self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm( + self.input_layernorm, attn_quant_scale, + dynamic_quant=self.is_per_token_sq_perf_cases) + attn_layernorm = self.quant_fusion_attn_layernorm + + return decoder_layer_forward_base( + positions, hidden_states, kv_cache, attn_metadata, + attn_layernorm, + self.self_attn, + self.post_attention_layernorm, + self.feed_forward, + input_norm_fuse_en=self.is_per_token_sq_perf_cases) + + +# ============================================================ +# Llama4Model hijack +# ============================================================ + +def vllm__llama4__Llama4Model__forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, IntermediateTensors]: + return decoder_model_forward_base_pp( + input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, + self.layers, self.start_layer, self.end_layer, + self.get_input_embeddings, + self.norm, + inputs_embeds) + + +# ============================================================ +# Llama4ForCausalLM load_weights hijack +# ============================================================ + +def vllm__llama4__Llama4ForCausalLM__load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], +): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: pack params for SparseMoeMlp (MoE layers) + ''' + for name, m in self.model.named_modules(): + if isinstance(m, 
SparseMoeMlp): + m.pack_params() + + start_expert_id = 0 + ''' + ================== + End of MLU Hijack + ================== + ''' + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # Permute Q/K weights for rotary embedding + name, loaded_weight = self.permute_qk_weight_for_rotary( + name, loaded_weight) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: remap expert_id for distributed inference + ''' + if (start_expert_id > 0 + and "feed_forward.experts." in name): + match = re.search(r'experts\.\d+', name) + if match: + expert_str = match.group(0) + expert_id = int(expert_str.split(".")[1]) + named_expert_id = expert_id - start_expert_id + name = name.replace( + f"experts.{expert_id}", + f"experts.{named_expert_id}") + ''' + ================== + End of MLU Hijack + ================== + ''' + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + # Skip experts not assigned to this worker + if ("feed_forward.experts." in name + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + # Skip experts not assigned to this worker + if ("feed_forward.experts." in name + and name not in params_dict): + continue + if name not in params_dict: + logger.warning( + "Skipping weight %s not present in the model", name) + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +# ============================================================ +# Apply all hijacks +# ============================================================ + +MluHijackObject.apply_hijack( + Llama4Attention, + Llama4Attention.__init__, + vllm__llama4__Llama4Attention____init__) +MluHijackObject.apply_hijack( + Llama4Attention, + Llama4Attention.forward, + vllm__llama4__Llama4Attention__forward) +MluHijackObject.apply_hijack( + Llama4DecoderLayer, + Llama4DecoderLayer.__init__, + vllm__llama4__Llama4DecoderLayer____init__) +MluHijackObject.apply_hijack( + Llama4DecoderLayer, + Llama4DecoderLayer.forward, + vllm__llama4__Llama4DecoderLayer__forward) +MluHijackObject.apply_hijack( + Llama4Model, + Llama4Model.forward, + vllm__llama4__Llama4Model__forward) +MluHijackObject.apply_hijack( + Llama4ForCausalLM, + Llama4ForCausalLM.load_weights, + vllm__llama4__Llama4ForCausalLM__load_weights)
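For reference, the per-layer behavior that both the base model and the MLU hijack implement (which layers get the MoE feed_forward, which layers skip RoPE, and how the NoPE attention temperature scale is computed) follows directly from the HF config fields used above. Below is a minimal standalone Python sketch, not part of the patch; the helper names are illustrative only and mirror the logic in Llama4DecoderLayer.__init__ and Llama4Attention.

import math

def is_moe_layer(layer_idx: int, interleave_moe_layer_step: int) -> bool:
    # Mirrors Llama4DecoderLayer.__init__: every interleave_moe_layer_step-th
    # layer (1-indexed) uses the MoE feed_forward; the rest stay dense.
    return (interleave_moe_layer_step > 0
            and (layer_idx + 1) % interleave_moe_layer_step == 0)

def is_nope_layer(layer_idx: int, no_rope_layers) -> bool:
    # Mirrors Llama4Attention.__init__: a 0 entry in no_rope_layers disables
    # RoPE for that layer (self.rotary_emb is None on NoPE layers).
    return no_rope_layers is not None and no_rope_layers[layer_idx] == 0

def attn_temperature_scale(position: int,
                           floor_scale: float = 8192.0,
                           attn_scale: float = 0.1) -> float:
    # Scalar version of Llama4Attention._get_attn_scale, applied to q on
    # NoPE layers when attn_temperature_tuning is enabled.
    return math.log(math.floor((position + 1) / floor_scale) + 1.0) * attn_scale + 1.0

# With interleave_moe_layer_step=2, the odd 0-indexed layers are MoE:
print([l for l in range(8) if is_moe_layer(l, 2)])  # [1, 3, 5, 7]
print(round(attn_temperature_scale(100_000), 3))    # 1.256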