################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Supa (Biren) hardware adaptation of vLLM's DeepSeek-V2/V3 model.

This module monkey-patches vLLM's ``deepseek_v2`` model implementation
(attention, indexer, MoE/MLP layers, weight loading) with Supa-specific
variants.  Importing it has side effects: several classes and methods in
``vllm.model_executor.models.deepseek_v2`` are replaced at module load time
(see the patch statements at the bottom of the file).
"""
from typing import Any, Iterable, Optional, Union

import torch
from fastcore.basics import patch_to
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config

import vllm
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.deepseek_v2 import (
    DeepseekV2ForCausalLM, DeepseekV2Model, FusedMoE, Indexer, PPMissingLayer,
    default_weight_loader, get_spec_layer_idx_from_weight_name,
    is_pp_missing_parameter, maybe_prefix, maybe_remap_kv_scale_name,
    yarn_get_mscale)
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata

from vllm_br.v1.attention.backends.mla.indexer import (
    SupaDeepseekV32IndexerBackend)

from .supa_module import (DeepseekV2MoE, MergedGateUpMLPSiluL2, SupaMLAModules,
                          SupaMultiHeadLatentAttention)


@patch_to(vllm.model_executor.models.deepseek_v2.DeepseekV32IndexerCache)
def get_attn_backend(self) -> AttentionBackend:
    """Route the V3.2 indexer cache to the Supa indexer attention backend.

    Patched onto ``DeepseekV32IndexerCache`` so vLLM picks the Supa backend
    class instead of the stock one.  Returns the class itself, not an
    instance (matches the patched method's contract in vLLM).
    """
    return SupaDeepseekV32IndexerBackend


class SupaDeepseekV2MLAAttention(nn.Module):
    """Supa replacement for vLLM's ``DeepseekV2MLAAttention``.

    Builds the same projection/layernorm/rope sub-modules as upstream and
    hands them to ``SupaMultiHeadLatentAttention`` via ``SupaMLAModules``.
    For DeepSeek-V3.2 (detected by ``config.index_topk``) it also creates a
    ``SupaIndexer`` for sparse attention top-k index selection.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        topk_indices_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        # V3.2 models expose `index_topk`; used throughout to switch between
        # dense (V2/V3) and sparse-indexer (V3.2) code paths.
        self.is_v32 = hasattr(config, "index_topk")
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Pre-declare all optional projections as None so SupaMLAModules can
        # be populated unconditionally below; exactly one of the mutually
        # exclusive q-path / kv-path layouts is instantiated.
        self.fused_qkv_a_proj = None
        self.kv_a_proj_with_mqa = None
        self.q_a_proj = None
        self.q_a_layernorm = None
        self.q_b_proj = None
        self.q_proj = None
        if self.is_v32:
            if self.q_lora_rank is not None:
                # V3.2 + q-LoRA: fuse q_a and kv_a into one merged projection
                # (checkpoint shards are remapped in load_weights).
                self.fused_qkv_a_proj = MergedColumnParallelLinear(
                    self.hidden_size, [
                        self.q_lora_rank,
                        self.kv_lora_rank + self.qk_rope_head_dim
                    ],
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.fused_qkv_a_proj",
                    disable_tp=True)
                # NOTE(review): `no_need_cross` appears to be a Supa-specific
                # flag consumed elsewhere (not visible here) — confirm.
                self.fused_qkv_a_proj.no_need_cross = True
            else:
                self.kv_a_proj_with_mqa = ReplicatedLinear(
                    self.hidden_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.kv_a_proj_with_mqa")
        else:
            # V2/V3 dense path: separate q_a and kv_a projections.
            if self.q_lora_rank is not None:
                self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                                 self.q_lora_rank,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_a_proj")
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa")
        if self.q_lora_rank is not None:
            # Low-rank q path: norm then up-projection to per-head q.
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            # Full-rank q path: single projection from hidden states.
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        if rope_scaling:
            # Select the rope implementation: stock deepseek yarn for V3.2,
            # Supa-specific variant otherwise.
            if self.is_v32:
                rope_scaling["rope_type"] = 'deepseek_yarn'
            else:
                rope_scaling["rope_type"] = 'deepseek_yarn_supa'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)

        if rope_scaling:
            # Yarn scaling changes the effective softmax temperature; fold
            # the mscale correction into the attention scale (matches
            # upstream DeepseekV2MLAAttention).
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        if self.is_v32:
            # Sparse-attention indexer (V3.2 only); writes selected token
            # indices into the shared topk_indices_buffer.
            self.indexer: Optional[SupaIndexer] = SupaIndexer(
                vllm_config, config, hidden_size, q_lora_rank, quant_config,
                cache_config, topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer: Optional[SupaIndexer] = None

        # Bundle everything for the Supa MLA kernel wrapper.
        mla_modules = SupaMLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm,
            q_b_proj=self.q_b_proj,
            q_proj=self.q_proj,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
            q_a_proj=self.q_a_proj,
        )
        self.mla_attn = SupaMultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the Supa MLA attention, flagging the V3.2 path."""
        return self.mla_attn(positions, hidden_states, is_ds_v32=self.is_v32)


def indexer_k_cache(
        k: torch.Tensor,  # [num_tokens, head_dim] # (8, 128)
        kv_cache: torch.
        Tensor,  # [1, num_blocks, block_size, cache_stride] # (1, 1024, 2048, 128)
        slot_mapping: torch.Tensor,  # [num_tokens] # (8)
) -> None:
    """Scatter indexer keys into the paged k-cache, one token per slot.

    Reference (non-kernel) implementation: each token's key row is written
    into ``kv_cache[0][slot // block_size][slot % block_size][:head_dim]``.
    Mutates ``kv_cache`` in place; returns None.
    """
    num_tokens = k.shape[0]
    head_dim = k.shape[1]
    # [TODO] kv_cache shape is not aligned with nv
    cache_block_size = kv_cache.shape[-2]
    # NOTE(review): Python-level loop over tokens — fine as a reference
    # path, but O(num_tokens) host-side iterations.
    for idx in range(num_tokens):
        slot_idx = slot_mapping[idx]
        k_idx = k[idx]
        block_idx = slot_idx // cache_block_size
        block_offset = slot_idx % cache_block_size
        kv_cache[0][block_idx][
            block_offset][:
                          head_dim] = k_idx  # [TODO] kv cache stride is longer than head_dim


def bf16_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
):
    """Compute bf16 MQA indexer logits for prefill chunks.

    Computes per-(query, key) scores via ``einsum("mhd,nd->hmn")``, applies
    ReLU, takes a head-weighted sum, and masks keys outside each query's
    ``[cu_seqlen_ks, cu_seqlen_ke)`` window with -inf.

    Returns a float32 logits tensor of shape [num_q_tokens, seq_len_kv].
    """
    seq_len_kv = kv.shape[0]

    k = kv
    q = q.float()
    k = k.float()

    # NOTE(review): device="cuda" is hard-coded here although the rest of
    # this file targets Supa hardware — confirm the arange lands on the
    # intended device (q/kv device would be safer).
    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
               >= cu_seqlen_ks[:, None])
    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
               < cu_seqlen_ke[:, None])
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))
    return logits


def _ref_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    """Reference paged-MQA logits for the decode path.

    ``q`` is [batch, next_n, heads, head_dim]; the paged cache is reshaped
    by a factor of 16 to match the NV-style block layout (see the *16 //16
    adjustment below).  Produces [batch * next_n, max_model_len] logits,
    with causal masking relative to each decode token's position;
    out-of-context positions stay -inf.
    """
    batch_size, next_n, _, _ = q.size()
    # NOTE(review): "unkonw_size" is a typo for "unknown_size" (local name
    # only, no behavioral impact).
    _, num_block, block_size, unkonw_size, head_dim = kv_cache.size(
    )  # [1, num_block, block_size, _]
    # Re-chunk the cache: Supa stores 16x larger blocks than the NV layout
    # this reference math assumes.
    num_block = num_block * 16
    block_size = block_size // 16
    kv_cache = kv_cache.view(num_block, block_size, unkonw_size, head_dim)
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens_list = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens_list[i]
        # Positions of this request's next_n decode tokens.
        # NOTE(review): device="cuda" hard-coded — see bf16_mqa_logits note.
        q_offsets = torch.arange(context_len - next_n,
                                 context_len,
                                 device="cuda")
        weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
            0, 1).contiguous())
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size,
                (block_rk + 1) * block_size,
                device="cuda",
            )
            # Causal + in-context mask for this cache block.
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
                                                         <= q_offsets[:, None])
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype),
                float("-inf"),
            )
            # ReLU then head-weighted reduction, mirroring bf16_mqa_logits.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n:(i + 1) * next_n,
                block_rk * block_size:(block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
                            float("-inf"))
    return logits


def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [1, num_blocks, block_size, head_dim + 1]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    dst_scale,  # [cu_seq_lens[-1], 4]
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
    """Gather per-request indexer keys from the paged cache into dst_value.

    Fills ``dst_value`` in place with the contiguous (unpaged) key rows for
    every request, in batch order.  The fp8-scale path (``dst_scale``) is
    currently disabled — all scale handling is commented out, and
    ``dst_scale`` is ignored (callers pass None).
    """
    _, num_blocks, block_size, _ = kv_cache.shape
    # align to nv
    num_blocks = num_blocks * 16
    block_size = block_size // 16
    head_dim = dst_value.shape[-1]
    kv_cache = kv_cache.view(num_blocks, -1)

    expected_value = []
    # expected_scale = []
    for b in range(batch_size):
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            # Request contributes no tokens; skip.
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]

        value = []
        # Indices of all fully-populated blocks (the last block may be
        # partial and is handled separately below).
        full_block = torch.arange(tot - 1,
                                  device=kv_cache.device,
                                  dtype=torch.int32)
        # [TODO] not support index in tensor on br, run in cpu now
        non_remaining_value = kv_cache.cpu()[
            blocks.cpu()[full_block.cpu()], :block_size * head_dim].view(
                -1, head_dim)
        # non_remaining_scale = kv_cache[blocks[full_block],
        #                                block_size * head_dim:].view(-1, 4)
        remaining = s - (tot - 1) * block_size

        value = torch.cat([
            non_remaining_value,
            kv_cache.cpu()[blocks[-1], :remaining * head_dim].view(
                -1, head_dim)
        ],
                          dim=0)
        # scale = torch.cat([
        #     non_remaining_scale,
        #     kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
        #              remaining * 4].view(-1, 4)
        # ],
        #                   dim=0)
        expected_value.append(value)
        # expected_scale.append(scale)
    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    # Reinterpret the raw cache bytes as bf16 and move back to the target
    # device (the gather above ran on CPU — see [TODO] note).
    gather_value = gather_value.view(torch.bfloat16).to(dst_value.device)
    # gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    # dst_scale.copy_(gather_scale)


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake (profile-run) variant of sparse_attn_indexer.

    Allocates the worst-case flattened kv buffer so the memory profiler
    measures a realistic peak, then returns the untouched index buffer.
    """
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    support_fp8 = False
    if support_fp8:
        _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                    device=k.device,
                                    dtype=torch.uint8)
        # NOTE(review): _k_fp8/_k_scale are intentionally unused — they only
        # force the allocations the real fp8 path would make.
        _k_fp8 = _flattened_kv[..., :head_dim].view(
            torch.float8_e4m3fn).contiguous()
        _k_scale = _flattened_kv[...,
                                 head_dim:].view(torch.float32).contiguous()
    else:
        _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                    device=k.device,
                                    dtype=torch.bfloat16)
    return topk_indices_buffer


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # careful!
    # careful: attn_metadata will be None (not a dict) in a dummy/profile
    # run — fall through to the fake path in that case.
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    assert topk_indices_buffer is not None
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Insert the new keys into the paged indexer cache first; both the
    # prefill and decode branches below read from it.
    indexer_k_cache(
        k,
        kv_cache,
        slot_mapping,
    )

    # Reset indices for the current tokens to "not selected".
    # NOTE(review): indexes with hidden_states.shape[1] — assumes the SUPA
    # 3D layout [1, num_tokens, hidden]; confirm against the patched
    # DeepseekV2Model.forward which unsqueezes dim 0.
    topk_indices_buffer[:hidden_states.shape[1]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            # Gather this chunk's keys out of the paged cache into a flat
            # bf16 buffer (scale path disabled, hence k_scale=None).
            k_bf16 = torch.empty([chunk.total_seq_lens, head_dim],
                                 device=k.device,
                                 dtype=torch.bfloat16)
            k_scale = None
            cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_bf16,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
                chunk.num_reqs,
            )
            logits = bf16_mqa_logits(
                q_fp8[chunk.token_start:chunk.token_end],
                k_bf16,
                weights[chunk.token_start:chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            # [TODO] topk is not aligned with cpu if elements are -inf
            # NOTE(review): .supa() presumably moves the result back to the
            # Biren device after the CPU topk workaround — confirm.
            topk_indices = logits.cpu().topk(min(topk_tokens,
                                                 logits.shape[-1]),
                                             dim=-1)[1].supa()
            # Re-base indices to be relative to each query's kv window, then
            # invalidate anything that fell outside it.
            topk_indices -= chunk.cu_seqlen_ks[:, None]
            mask_lo = topk_indices >= 0
            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
                                      chunk.cu_seqlen_ks)[:, None] < 0
            # NOTE(review): dead store — this full_like tensor is
            # immediately overwritten by `mask = mask_lo & mask_hi` below.
            mask = torch.full_like(topk_indices,
                                   False,
                                   dtype=torch.bool,
                                   device=topk_indices.device)
            mask = mask_lo & mask_hi
            topk_indices = topk_indices.masked_fill(~mask, -1)
            topk_indices_buffer[
                chunk.token_start:chunk.token_end, :topk_indices.
                shape[-1]] = topk_indices.to(dtype=torch.int32)

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = _ref_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len=max_model_len,
        )
        # padded query len
        current_device = padded_q_fp8_decode_tokens.device
        padded_num_tokens = batch_size * next_n
        # Build, per padded decode token, the last kv position it may
        # attend to (its own position), then mask everything beyond it.
        positions = torch.arange(max_model_len,
                                 device=current_device).unsqueeze(0).expand(
                                     batch_size * next_n, -1)
        row_indices = torch.arange(padded_num_tokens,
                                   device=current_device) // next_n
        next_n_offset = torch.arange(
            padded_num_tokens,
            device=padded_q_fp8_decode_tokens.device) % next_n
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
                         next_n_offset).unsqueeze(1)
        # index_end_pos: [B * N, 1]
        mask = positions <= index_end_pos
        # mask: [B * N, L]
        logits = logits.masked_fill(~mask, float('-inf'))
        # [TODO] topk is not supported
        # (device-side topk unavailable; round-trip through CPU)
        device = logits.device
        logits = logits.to('cpu')
        topk_indices = logits.topk(topk_tokens,
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
        topk_indices = topk_indices.to(device)
        # ensure we don't set indices for the top k
        # that is out of range(masked already)
        # this will happen if context length is shorter than K
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens)
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)

    return topk_indices_buffer


class SupaIndexer(Indexer):
    """Supa variant of the DeepSeek-V3.2 sparse-attention indexer.

    Differences from the upstream ``Indexer``: keeps the weights projection
    unquantized, stores indexer keys as bf16 in the k-cache (fp8 path is
    disabled), and zero-initializes the shared top-k index buffer.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: Optional[int],
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor] = None,
                 prefix: str = "") -> None:
        super().__init__(
            vllm_config=vllm_config,
            config=config,
            hidden_size=hidden_size,
            q_lora_rank=q_lora_rank,
            quant_config=quant_config,
            cache_config=cache_config,
            topk_indices_buffer=topk_indices_buffer,
            prefix=prefix,
        )
        self.n_head = config.index_n_heads  # 64
        # Re-create weights_proj without quantization (overrides the one the
        # base Indexer built with quant_config).
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             bias=False,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        # Supa stores indexer keys as bf16, not fp8.
        self.k_cache.dtype = torch.bfloat16
        self.k_cache.head_dim = config.index_head_dim
        self.topk_indices_buffer.fill_(0)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor,
                positions, rotary_emb) -> torch.Tensor:
        """Compute per-token top-k attended indices.

        Projects q (from the compressed qr) and k (from hidden_states),
        applies rope to the rope-dim slice of each, then dispatches to
        ``sparse_attn_indexer`` which fills and returns the shared
        topk_indices_buffer.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # Split into rope / non-rope parts; only the rope slice is rotated.
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        k, _ = self.wk(hidden_states)
        k = k.view(-1, self.head_dim)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = 
        q.view(-1, self.head_dim)
        # fp8 quantization of q is currently disabled on Supa; the bf16
        # fallback below is the live path.
        support_fp8 = False
        if support_fp8:
            q_fp8, q_scale = per_token_group_quant_fp8(
                q,
                self.quant_block_size,
                column_major_scales=False,
                use_ue8m0=self.scale_fmt is not None)
            q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
            q_scale = q_scale.view(-1, self.n_head, 1)

            weights, _ = self.weights_proj(hidden_states)
            # Fold the per-token q scale and softmax scale into the head
            # weights so the logits kernel needs no extra scaling.
            weights = weights.unsqueeze(
                -1) * q_scale * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)

            return torch.ops.vllm.sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            q = q.view(-1, self.n_head, self.head_dim)
            weights, _ = self.weights_proj(hidden_states)
            weights = weights.view(-1, self.n_head)
            weights = weights.unsqueeze(
                -1) * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)
            # NOTE(review): this branch calls the Python function directly
            # instead of the registered torch.ops.vllm custom op used in the
            # fp8 branch — presumably deliberate (no fake-op dispatch needed
            # for bf16); confirm torch.compile compatibility.
            return sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )


@patch_to(DeepseekV2Model)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched DeepseekV2Model.forward that runs layers on 3D tensors.

    Identical to upstream except hidden_states/residual are unsqueezed to
    [1, num_tokens, hidden] before the decoder layers (SUPA kernels expect
    a 3D layout) and squeezed back before returning / crossing pipeline
    stages.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        # Incoming PP tensors are 2D; restore the batch dim.
        residual = residual.unsqueeze(0)
    # NOTE: SUPA wants 3D input
    hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        # Hand 2D tensors to the next pipeline stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)


@patch_to(DeepseekV2ForCausalLM)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Patched DeepseekV2ForCausalLM.__init__ enabling Supa DS-MLA flags.

    Mirrors the upstream constructor but sets ``model_config.use_ds_mla``
    (and ``use_ds_mla_sparse`` for V3.2) before building the model so
    downstream Supa components can key off them.
    """
    super(DeepseekV2ForCausalLM, self).__init__()
    config = vllm_config.model_config.hf_config
    model_config = vllm_config.model_config
    model_config.use_ds_mla = True
    is_v32 = hasattr(config, "index_topk")
    if is_v32:
        model_config.use_ds_mla_sparse = True
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config

    # `packed_modules_mapping` needs to be modified before
    # initializing DeepseekV2Model, as it is passed inplace to
    # quantization config init and may be used to select the
    # quant_method for relevant layers during initialization.
    self.fuse_qkv_a_proj = hasattr(
        config, "q_lora_rank") and config.q_lora_rank is not None
    if self.fuse_qkv_a_proj:
        self.packed_modules_mapping["fused_qkv_a_proj"] = [
            "q_a_proj",
            "kv_a_proj_with_mqa",
        ]

    self.model = DeepseekV2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
    else:
        # Non-last PP ranks hold no LM head.
        self.lm_head = PPMissingLayer()
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)


@patch_to(DeepseekV2ForCausalLM)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Patched weight loader: upstream logic plus Supa layout handling.

    Supa additions vs upstream: norm/e_score_correction_bias params are
    upcast to float32 after loading, and the device cache is flushed
    (``torch.supa.empty_cache``) per loaded weight to bound peak memory.
    Returns the set of parameter names actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
    ]
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts)

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            # Rope inverse frequencies are recomputed, never loaded.
            continue
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        # First try the stacked (fused) parameter mappings; the for/else
        # falls through to expert params, then to plain params.
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name_mapped = name.replace(weight_name, param_name)
            # QKV fusion is optional, fall back to normal
            # weight loading if it's not enabled
            # if go with fusion option, then update name
            if ((param_name == "fused_qkv_a_proj")
                    and name_mapped not in params_dict):
                continue
            else:
                name = name_mapped
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                # logger.debug(f'skip {name}')
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            # Supa-specific: keep norm weights / router correction bias in
            # fp32 after load.
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
            # NOTE(review): per-weight cache flush — presumably needed to
            # keep peak device memory down on Supa; it will slow loading.
            torch.supa.empty_cache()
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
                break
            else:
                # Plain (non-stacked, non-expert) parameter path.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
        loaded_params.add(name)
    return loaded_params


# ---------------------------------------------------------------------------
# Module-level monkey patches: replace vLLM's stock DeepSeek-V2 components
# with the Supa variants defined above / in supa_module.  These run once at
# import time of this plugin module.
# ---------------------------------------------------------------------------
vllm.model_executor.models.deepseek_v2.DeepseekV2MLP = MergedGateUpMLPSiluL2
logger.debug('[Patch] patch DeepSeekV2 MLP with MergedGateUpMLPSiluL2')

vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = DeepseekV2MoE
logger.debug('[Patch] patch DeepSeekV2 MoE with DeepseekV2MoE')

vllm.model_executor.models.deepseek_v2.DeepseekV2MLAAttention = SupaDeepseekV2MLAAttention
logger.debug('[Patch] patch DeepSeekV2 MLA with SupaDeepseekV2MLAAttention')

vllm.model_executor.models.deepseek_v2.Indexer = SupaIndexer
logger.debug('[Patch] patch DeepSeekV2 Indexer with SupaIndexer')

vllm.model_executor.models.deepseek_v2.MultiHeadLatentAttention = SupaMultiHeadLatentAttention
logger.debug(
    '[Patch] patch DeepSeekV2 MultiHeadLatentAttention with SupaMultiHeadLatentAttention'
)

# vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.packed_modules_mapping = {
#     "gate_up_proj": ["gate_proj", "up_proj"],
#     # "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
# }
# logger.debug(
#     '[Patch] patch DeepseekV2ForCausalLM with SupportsQuant packed_modules_mapping'
# )