# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model.""" import typing from collections.abc import Callable, Iterable from itertools import islice from typing import Any import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) from vllm.model_executor.models.utils import sequence_parallel_chunk from 
vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .utils import ( PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops as ops import ixformer.inference.functions as ixfops logger = init_logger(__name__) class DeepseekAttention(nn.Module): """Normal MHA implementation used by Deepseek v1.""" def __init__( self, vllm_config: VllmConfig, config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, rope_theta: float = 10000, rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", **kwargs, ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = config.num_key_value_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=False, quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, ) self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output class DeepseekV2MLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, quant_config: QuantizationConfig | None = None, reduce_results: bool = True, is_sequence_parallel=False, prefix: str = "", ) -> None: super().__init__() # If is_sequence_parallel, the input and output tensors are sharded # across the ranks within the tp_group. In this case the weights are # replicated and no collective ops are needed. # Otherwise we use standard TP with an allreduce at the end. 
self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, disable_tp=is_sequence_parallel, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, disable_tp=is_sequence_parallel, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": raise ValueError( f"Unsupported activation: {hidden_act}. Only silu is supported for now." ) self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x class DeepseekV2MoE(nn.Module): def __init__( self, config: DeepseekV2Config | DeepseekV3Config, parallel_config: ParallelConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.ep_group = get_ep_group().device_group self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if config.hidden_act != "silu": raise ValueError( f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now." ) self.gate = ReplicatedLinear( config.hidden_size, config.n_routed_experts, bias=False, quant_config=None, prefix=f"{prefix}.gate", ) if getattr(config, "topk_method", None) == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts) ) else: self.gate.e_score_correction_bias = None # Load balancing settings. 
eplb_config = parallel_config.eplb_config self.enable_eplb = parallel_config.enable_eplb self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = self.ep_rank * self.n_local_physical_experts self.physical_expert_end = ( self.physical_expert_start + self.n_local_physical_experts ) self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled: self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, is_sequence_parallel=self.is_sequence_parallel, reduce_results=False, prefix=f"{prefix}.shared_experts", ) self.experts = SharedFusedMoE( shared_experts=self.shared_experts, gate=self.gate, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, num_expert_group=getattr(config, "n_group", 1), topk_group=getattr(config, "topk_group", 1), prefix=f"{prefix}.experts", scoring_func=getattr(config, "scoring_func", "softmax"), # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 if not self.is_rocm_aiter_moe_enabled else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, n_shared_experts=config.n_shared_experts if 
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() else None, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # Chunk the hidden states so they aren't replicated across TP ranks. # This avoids duplicate computation in self.experts. # TODO: We can replace the all_reduce at the end of attn with a # reduce_scatter instead of chunking here. if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) if self.experts.is_internal_router: # In this case, the gate/router runs inside the FusedMoE class fused_moe_out = self.experts( hidden_states=hidden_states, router_logits=hidden_states ) else: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) fused_moe_out = self.experts( hidden_states=hidden_states, router_logits=router_logits ) shared_output, final_hidden_states = fused_moe_out if self.shared_experts is None: assert shared_output is None # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. 
if hidden_states.dtype != torch.float16: if not self.is_rocm_aiter_moe_enabled: final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None shared_output *= 1.0 / self.routed_scaling_factor if self.shared_experts is not None: assert shared_output is not None final_hidden_states += shared_output if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states, 0 ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states ) return final_hidden_states.view(num_tokens, hidden_dim) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV2Attention(nn.Module): def __init__( self, vllm_config: VllmConfig, config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, q_lora_rank: int, kv_lora_rank: int, rope_theta: float = 10000, rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, topk_indices_buffer: torch.Tensor | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.num_heads = num_heads tp_size = get_tensor_model_parallel_world_size() assert num_heads % tp_size == 0 self.num_local_heads = num_heads // tp_size self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings assert topk_indices_buffer is None, ( "topk_indices_buffer 
is not \ supported for DeepseekV2Attention" ) if self.q_lora_rank is not None: self.q_a_proj = ReplicatedLinear( self.hidden_size, self.q_lora_rank, bias=False, quant_config=quant_config, prefix=f"{prefix}.q_a_proj", ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( q_lora_rank, self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.q_b_proj", ) else: self.q_proj = ColumnParallelLinear( self.hidden_size, self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.q_proj", ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_a_proj_with_mqa", ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj", ) # O projection. 
self.o_proj = RowParallelLinear( self.num_heads * self.v_head_dim, self.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale self.attn = Attention( self.num_local_heads, self.qk_head_dim, self.scaling, num_kv_heads=self.num_local_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", ) def forward_opt( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: if self.q_lora_rank is not None: q_latent_kpe = self.q_a_proj(hidden_states)[0] q, kv_a, k_pe = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=1) q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: q_latent_kpe = self.q_proj(hidden_states)[0] q, kv_a, k_pe = q_latent_kpe.split([self.num_heads * self.qk_head_dim, self.kv_lora_rank, self.qk_rope_head_dim], dim=1) q = q.view(-1, self.num_local_heads, self.qk_head_dim) _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.empty_like(q) v = torch.empty(q.shape[0], self.num_local_heads, self.v_head_dim, device=q.device, dtype=q.dtype) ixfops.mla_rope(positions, q_pe, k_pe, k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) ixfops.mla_copy_kv(k_nope, v_nope, k, v) attn_output = 
self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: q = self.q_proj(hidden_states)[0] kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) q = q.view(-1, self.num_local_heads, self.qk_head_dim) _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope k[..., self.qk_nope_head_dim :] = k_pe # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( v, [0, self.qk_head_dim - self.v_head_dim], value=0 ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ ..., : self.v_head_dim ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): def __init__( self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig ): super().__init__() self.kv_cache = [torch.tensor([])] self.head_dim = head_dim self.prefix = prefix self.cache_config = cache_config self.dtype = dtype compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") 
compilation_config.static_forward_context[prefix] = self def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return MLAAttentionSpec( # Only has one vector instead of K + V block_size=self.cache_config.block_size, num_kv_heads=1, head_size=self.head_dim, dtype=self.dtype, ) def forward(self): ... def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend @torch.inference_mode() def cp_gather_indexer_k_quant_cache( kv_cache, # [num_blocks, block_size, head_dim] dst_value, # [cu_seq_lens[-1], head_dim] block_table, # [batch_size, num_blocks] cu_seq_lens, # [batch_size + 1, ] batch_size, ): num_blocks, block_size, _ = kv_cache.shape head_dim = dst_value.shape[-1] kv_cache = kv_cache.view(num_blocks, -1) expected_value = [] # expected_scale = [] for b in range(batch_size): s = cu_seq_lens[b + 1] - cu_seq_lens[b] if s == 0: continue tot = cdiv(s, block_size) blocks = block_table[b, :tot] value = [] scale = [] full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) non_remaining_value = kv_cache[blocks[full_block], :block_size * head_dim].view(-1, head_dim) # non_remaining_scale = kv_cache[blocks[full_block], # block_size * head_dim:].view(-1, 4) remaining = s - (tot - 1) * block_size value = torch.cat([ non_remaining_value, kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) ], dim=0) # scale = torch.cat([ # non_remaining_scale, # kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim + # remaining * 4].view(-1, 4) # ], # dim=0) expected_value.append(value) # expected_scale.append(scale) gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) gather_value = gather_value.view(torch.bfloat16) # gather_scale = gather_scale.view(torch.float32) dst_value.copy_(gather_value) # dst_scale.copy_(gather_scale) def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, kv_cache: torch.Tensor, q: torch.Tensor, k: 
torch.Tensor, weights: torch.Tensor, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): return sparse_attn_indexer_fake( hidden_states, k_cache_prefix, kv_cache, q, k, weights, topk_tokens, head_dim, max_model_len, total_seq_lens, topk_indices_buffer, ) attn_metadata = attn_metadata[k_cache_prefix] assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) slot_mapping = attn_metadata.slot_mapping has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens ops.indexer_k_cache( k, kv_cache, slot_mapping ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill for chunk in prefill_metadata.chunks: k = torch.empty( [chunk.total_seq_lens, head_dim], device=k.device, dtype=torch.bfloat16, ) # k_scale = torch.empty( # [chunk.total_seq_lens, 4], # device=k.device, # dtype=torch.uint8, # ) cp_gather_indexer_k_quant_cache( kv_cache, k, chunk.block_table, chunk.cu_seq_lens, chunk.num_reqs, ) logits = ops.ref_mqa_logits( q[chunk.token_start:chunk.token_end], k, weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1] topk_indices -= chunk.cu_seqlen_ks[:, None] mask_lo = topk_indices >= 0 mask_hi = topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0 mask = torch.full_like(topk_indices, False, dtype=torch.bool, device=topk_indices.device) mask = mask_lo & mask_hi topk_indices = topk_indices.masked_fill(~mask, -1) topk_indices_buffer[ chunk.token_start:chunk.token_end, :topk_indices. 
shape[-1]] = topk_indices.to(dtype=torch.int32) if has_decode: decode_metadata = attn_metadata.decode # kv_cache size requirement [num_block, block_size, n_head, head_dim], # we only have [num_block, block_size, head_dim], kv_cache = kv_cache.unsqueeze(-2) decode_lens = decode_metadata.decode_lens if decode_metadata.requires_padding: # pad in edge case where we have short chunked prefill length < # decode_threshold since we unstrictly split # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) padded_q_decode_tokens = pack_seq_triton( q[:num_decode_tokens], decode_lens) else: padded_q_decode_tokens = q[:num_decode_tokens].reshape( decode_lens.shape[0], -1, *q.shape[1:]) # TODO: move and optimize below logic with triton kernels batch_size = padded_q_decode_tokens.shape[0] next_n = padded_q_decode_tokens.shape[1] assert batch_size == decode_metadata.seq_lens.shape[0] num_padded_tokens = batch_size * next_n logits = ops.ref_paged_mqa_logits( padded_q_decode_tokens, kv_cache, weights[:num_padded_tokens], decode_metadata.seq_lens, decode_metadata.block_table, max_model_len=max_model_len, ) # padded query len current_device = padded_q_decode_tokens.device padded_num_tokens = batch_size * next_n positions = torch.arange(max_model_len, device=current_device).unsqueeze(0).expand( batch_size * next_n, -1) row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n next_n_offset = torch.arange( padded_num_tokens, device=padded_q_decode_tokens.device) % next_n index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1) # index_end_pos: [B * N, 1] mask = positions <= index_end_pos # mask: [B * N, L] logits = logits.masked_fill(~mask, float('-inf')) topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K] # ensure we don't set indices for the top k # that is out of range(masked already) # this will happen if context length is shorter than K topk_indices[topk_indices > 
index_end_pos] = -1 if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens topk_indices = unpack_seq_triton( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens, ) topk_indices_buffer[:num_decode_tokens, :topk_indices. shape[-1]] = topk_indices.to(dtype=torch.int32) return topk_indices_buffer def sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: str, kv_cache: torch.Tensor, q: torch.Tensor, k: torch.Tensor, weights: torch.Tensor, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: # profile run # NOTE(Chen): create the max possible flattened_kv. So that # profile_run can get correct memory usage. _flattened_kv = torch.empty([total_seq_lens, head_dim], device=k.device, dtype=torch.bfloat16) _k = _flattened_kv[..., :head_dim].view( torch.bfloat16).contiguous() # _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer direct_register_custom_op( op_name="sparse_attn_indexer", op_func=sparse_attn_indexer, mutates_args=["topk_indices_buffer"], fake_impl=sparse_attn_indexer_fake, dispatch_key=current_platform.dispatch_key, ) class Indexer(nn.Module): def __init__( self, vllm_config: VllmConfig, config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, q_lora_rank: int, quant_config: QuantizationConfig | None, cache_config: CacheConfig | None, topk_indices_buffer: torch.Tensor | None, prefix: str = "", ): super().__init__() self.vllm_config = vllm_config self.config = config # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] self.topk_tokens = config.index_topk self.n_head = config.index_n_heads # 64 self.head_dim = config.index_head_dim # 128 self.rope_dim = config.qk_rope_head_dim # 64 self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, just replicated self.wq_b = ReplicatedLinear( self.q_lora_rank, self.head_dim * 
self.n_head, bias=False, quant_config=quant_config, prefix=f"{prefix}.wq_b", ) self.wk = ReplicatedLinear( hidden_size, self.head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.wk", ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.weights_proj = ReplicatedLinear( hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj" ) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0" self.quant_block_size = 128 # TODO: get from config self.topk_indices_buffer = topk_indices_buffer # NOTE: (zyongye) we use fp8 naive cache, # where we store value in fp8 and scale in fp32 # per self.quant_block_size element self.k_cache = DeepseekV32IndexerCache( head_dim=self.head_dim, dtype=torch.bfloat16, prefix=f"{prefix}.k_cache", cache_config=cache_config, ) self.max_model_len = vllm_config.model_config.max_model_len self.prefix = prefix from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) q_pe, q_nope = torch.split( q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) k, _ = self.wk(hidden_states) k = self.k_norm(k) k_pe, k_nope = torch.split( k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion # q = q.view(-1, self.head_dim) # q_fp8, q_scale = per_token_group_quant_fp8( # q, # self.quant_block_size, # column_major_scales=False, # use_ue8m0=self.scale_fmt is not None, # ) # q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) # q_scale = q_scale.view(-1, self.n_head, 1) weights, _ = self.weights_proj(hidden_states) weights = ( weights.unsqueeze(-1) * 
self.softmax_scale * self.n_head**-0.5 ) weights = weights.squeeze(-1) return torch.ops.vllm.sparse_attn_indexer( hidden_states, self.k_cache.prefix, self.k_cache.kv_cache[0], q, k, weights, self.topk_tokens, self.head_dim, self.max_model_len, self.max_total_seq_len, self.topk_indices_buffer, ) class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). For more info see MLACommonImpl in: vllm/v1/attention/backends/mla/utils.py """ def __init__( self, vllm_config: VllmConfig, config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, q_lora_rank: int | None, kv_lora_rank: int, rope_theta: float = 10000, rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", topk_indices_buffer: torch.Tensor | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.num_heads = num_heads tp_size = get_tensor_model_parallel_world_size() assert num_heads % tp_size == 0 self.num_local_heads = num_heads // tp_size self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: self.q_a_proj = ReplicatedLinear(self.hidden_size, self.q_lora_rank, bias=False, quant_config=quant_config, prefix=f"{prefix}.q_a_proj") self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False, 
quant_config=quant_config, prefix=f"{prefix}.q_b_proj") else: self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.q_proj") self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_a_proj_with_mqa") self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj", ) self.o_proj = RowParallelLinear( self.num_heads * self.v_head_dim, self.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale self.is_v32 = hasattr(config, "index_topk") if self.is_v32: self.indexer = Indexer( vllm_config, config, hidden_size, q_lora_rank, quant_config, cache_config, topk_indices_buffer, f"{prefix}.indexer", ) else: self.indexer = None mla_modules = MLAModules( kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, rotary_emb=self.rotary_emb, o_proj=self.o_proj, q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa if self.q_lora_rank is None else None, q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, 
indexer=self.indexer, is_sparse=self.is_v32, topk_indices_buffer=topk_indices_buffer, ) self.mla_attn = MultiHeadLatentAttentionWrapper( self.hidden_size, self.num_local_heads, self.scaling, self.qk_nope_head_dim, self.qk_rope_head_dim, self.v_head_dim, self.q_lora_rank, self.kv_lora_rank, mla_modules, cache_config, quant_config, prefix, ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: return self.mla_attn(positions, hidden_states) class DeepseekV2DecoderLayer(nn.Module): def __init__( self, vllm_config: VllmConfig, prefix: str, config: DeepseekV2Config | None = None, topk_indices_buffer: torch.Tensor | None = None, ) -> None: super().__init__() if config is None: config = vllm_config.model_config.hf_config model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) moe_layer_freq = getattr(config, "moe_layer_freq", 1) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx # verify MLA attention specific fields qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) v_head_dim = getattr(config, "v_head_dim", 0) kv_lora_rank = getattr(config, "kv_lora_rank", 0) use_mha = config.model_type == "deepseek" or all( dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) ) if use_mha: attn_cls = DeepseekAttention elif model_config.use_mla: attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( vllm_config=vllm_config, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, qk_nope_head_dim=qk_nope_head_dim, qk_rope_head_dim=qk_rope_head_dim, v_head_dim=v_head_dim, q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, kv_lora_rank=kv_lora_rank, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", topk_indices_buffer=topk_indices_buffer, ) if ( config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % moe_layer_freq == 0 ): self.mlp = DeepseekV2MoE( config=config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.mlp", ) else: self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: residual = 
hidden_states.clone() hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) if ( not isinstance(self.self_attn, DeepseekAttention) and hidden_states.dtype == torch.float16 ): # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. hidden_states *= 1.0 / self.routed_scaling_factor if self.layer_idx == 0: # The residual is shared by all layers, we only scale it on # first layer. residual *= 1.0 / self.routed_scaling_factor # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: # Fix FP16 overflow # Scaling the DeepseekV2MLP output, it is the input of # input_layernorm of next decoder layer. # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE hidden_states *= 1.0 / self.routed_scaling_factor return hidden_states, residual @support_torch_compile class DeepseekV2Model(nn.Module): fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.device = current_platform.device_type self.vocab_size = config.vocab_size self.is_v32 = hasattr(config, "index_topk") if self.is_v32: topk_tokens = config.index_topk topk_indices_buffer = torch.empty( vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, device=self.device, ) else: topk_indices_buffer = None if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=f"{prefix}.embed_tokens", ) else: self.embed_tokens = 
PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: DeepseekV2DecoderLayer( vllm_config, prefix, topk_indices_buffer=topk_indices_buffer ), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class DeepseekV2MixtureOfExperts(MixtureOfExperts): moe_mlp_layers: list[DeepseekV2MoE] """ List of MoE MLP layers in the model. 
""" def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None): if example_moe is None: self.num_moe_layers = 0 self.num_expert_groups = 0 self.num_logical_experts = 0 self.num_physical_experts = 0 self.num_local_physical_experts = 0 self.num_routed_experts = 0 self.num_shared_experts = 0 self.num_redundant_experts = 0 logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.") else: self.num_logical_experts = example_moe.n_logical_experts self.num_physical_experts = example_moe.n_physical_experts self.num_local_physical_experts = example_moe.n_local_physical_experts self.num_routed_experts = example_moe.n_routed_experts self.num_shared_experts = example_moe.n_shared_experts self.num_redundant_experts = example_moe.n_redundant_experts def update_physical_experts_metadata( self, num_physical_experts: int, num_local_physical_experts: int, ) -> None: assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts for moe in self.moe_mlp_layers: moe.n_local_physical_experts = num_local_physical_experts moe.n_physical_experts = num_physical_experts moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() class DeepseekV2ForCausalLM( nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle ): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) self.use_mha = config.model_type == "deepseek" or all( dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) ) if 
self.use_mha: self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the # quant_method for relevant layers during initialization. # self.fuse_qkv_a_proj = ( # hasattr(config, "q_lora_rank") and config.q_lora_rank is not None # ) # if self.fuse_qkv_a_proj: # self.packed_modules_mapping["fused_qkv_a_proj"] = [ # "q_a_proj", # "kv_a_proj_with_mqa", # ] self.model = DeepseekV2Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) # Set MoE hyperparameters self.num_moe_layers = ( self.config.num_hidden_layers - self.config.first_k_dense_replace ) self.set_moe_parameters() def set_moe_parameters(self): self.expert_weights = [] self.num_expert_groups = getattr(self.config, "n_group", 1) self.moe_layers = [] self.moe_mlp_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): continue assert isinstance(layer, DeepseekV2DecoderLayer) if isinstance(layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. 
example_moe = layer.mlp if example_moe is None: raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") self.num_logical_experts = example_moe.n_logical_experts self.num_physical_experts = example_moe.n_physical_experts self.num_local_physical_experts = example_moe.n_local_physical_experts self.num_routed_experts = example_moe.n_routed_experts self.num_shared_experts = example_moe.n_shared_experts self.num_redundant_experts = example_moe.n_redundant_experts def set_eplb_state( self, expert_load_view: torch.Tensor, logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, ) -> None: for layer_idx, layer in enumerate(self.moe_layers): # Register the expert weights. self.expert_weights.append(layer.get_expert_weights()) layer.set_eplb_state( moe_layer_idx=layer_idx, expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, ) def update_physical_experts_metadata( self, num_physical_experts: int, num_local_physical_experts: int, ) -> None: assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts self.num_redundant_experts = (num_physical_experts - self.num_logical_experts) for layer in self.model.layers: if isinstance(layer.mlp, DeepseekV2MoE): moe = layer.mlp moe.n_local_physical_experts = num_local_physical_experts moe.n_physical_experts = num_physical_experts moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) 
return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, num_redundant_experts=0, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: rocm_aiter_moe_shared_expert_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] # mla_params_mapping = [ # ("fused_qkv_a_proj", "q_a_proj", 0), # ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), # ] mha_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] if self.use_mha: stacked_params_mapping.extend(mha_params_mapping) # else: # stacked_params_mapping.extend(mla_params_mapping) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts + ( self.config.n_shared_experts if rocm_aiter_moe_shared_expert_enabled else 0 ), num_redundant_experts=self.num_redundant_experts, ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: try: if "rotary_emb.inv_freq" in name: continue spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue # skip spec decode layers for main model for 
(param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue # We have mlp.experts[0].gate_proj in the checkpoint. # Since we handle the experts below in expert_params_mapping, # we need to skip here BEFORE we update the name, otherwise # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. if (("mlp.experts." in name) and name not in params_dict): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) if is_pp_missing_parameter(name_mapped, self): continue param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. weight_loader = typing.cast(Callable[..., bool], param.weight_loader) success = weight_loader(param, loaded_weight, name_mapped, shard_id=shard_id, expert_id=expert_id, return_success=True) if success: name = name_mapped break else: if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank # So we simply skip it continue # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) except: pass opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"] # add your opt here.. def inject_layer(layer, quant_method, is_mla): q_lora_rank = getattr(layer, "q_lora_rank", None) if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]: if q_lora_rank is not None: layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) if hasattr(layer.q_a_proj, "weight_scale"): layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) del layer.kv_a_proj_with_mqa.weight_scale elif not is_mla: layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) if hasattr(layer.q_proj, "weight_scale"): layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) del layer.kv_a_proj_with_mqa.weight_scale else: return del layer.kv_a_proj_with_mqa.weight del layer.kv_a_proj_with_mqa if is_mla: layer.mla_attn.forward = layer.mla_attn.forward_opt else: layer.forward = layer.forward_opt elif quant_method == "GGUFLinearMethod": pass elif quant_method == "AWQMarlinLinearMethod": dtype = layer.kv_a_proj_with_mqa.qweight.dtype assert dtype == torch.int32 if layer.q_lora_rank is not None: layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1) layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) del 
layer.kv_a_proj_with_mqa.scales layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) del layer.kv_a_proj_with_mqa.qzeros elif not is_mla: layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1) layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) del layer.kv_a_proj_with_mqa.scales layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) del layer.kv_a_proj_with_mqa.qzeros else: return del layer.kv_a_proj_with_mqa.qweight del layer.kv_a_proj_with_mqa if is_mla: layer.mla_attn.forward = layer.mla_attn.forward_opt else: layer.forward = layer.forward_opt else: pass for _, layer in self.model.named_modules(): if layer.__class__.__name__ in ["DeepseekV2Attention","DeepseekV2MLAAttention"]: if hasattr(layer.kv_a_proj_with_mqa, "scheme"): quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__ else: quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__ if quant_method not in opt_support_quant_method: break inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "DeepseekV2MLAAttention") return loaded_params class DeepseekForCausalLM(DeepseekV2ForCausalLM): pass class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py def get_spec_layer_idx_from_weight_name( config: DeepseekV2Config | DeepseekV3Config, weight_name: str ) -> int | None: if ( hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0 ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None