diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py
index de63226..52478ba 100644
--- a/vllm_kunlun/models/deepseek_v2.py
+++ b/vllm_kunlun/models/deepseek_v2.py
@@ -54,7 +54,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm_kunlun.ops.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm_kunlun.ops.attention.mla import MLAModules, MultiHeadLatentAttention
+from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
diff --git a/vllm_kunlun/ops/attention/mla.py b/vllm_kunlun/ops/attention/mla.py
deleted file mode 100644
index e50ac2f..0000000
--- a/vllm_kunlun/ops/attention/mla.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-
-from vllm_kunlun.ops.attention.layer import Attention
-# from vllm.attention import Attention
-from vllm.config import CacheConfig
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
-
-
-@dataclass
-class MLAModules:
-    """Modules used in MLA.
-    """
-    kv_a_layernorm: torch.nn.Module
-    kv_b_proj: torch.nn.Module
-    rotary_emb: torch.nn.Module
-    o_proj: torch.nn.Module
-    fused_qkv_a_proj: Optional[torch.nn.Module]
-    kv_a_proj_with_mqa: Optional[torch.nn.Module]
-    q_a_layernorm: Optional[torch.nn.Module]
-    q_b_proj: Optional[torch.nn.Module]
-    q_proj: Optional[torch.nn.Module]
-    indexer: Optional[torch.nn.Module]
-    is_sparse: bool
-    topk_indices_buffer: Optional[torch.Tensor]
-
-
-@CustomOp.register("vllm_kunlun_multi_head_latent_attention")
-class MultiHeadLatentAttention(CustomOp):
-    """MLA layer registered as CustomOp.
-    Note that currently MLA ignores the enable/disable mechanism of CustomOp
-    because there is only one in-tree implementation in forward_native.
-    TODO: implement this with a new PluggableLayer mechanism.
-
-    This class takes positions and hidden_states as input.
-    The input tensors can either contain prefill tokens or decode tokens.
-    The class does the following:
-
-    1. MLA Preprocess.
-    2. Perform multi-head attention to prefill tokens and
-       multi-query attention to decode tokens separately.
-    3. Return the output tensor.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        scale: float,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        mla_modules: MLAModules,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.num_heads = num_heads
-        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
-        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
-        self.q_a_layernorm = mla_modules.q_a_layernorm
-        self.q_b_proj = mla_modules.q_b_proj
-        self.q_proj = mla_modules.q_proj
-        self.kv_a_layernorm = mla_modules.kv_a_layernorm
-        self.kv_b_proj = mla_modules.kv_b_proj
-        self.rotary_emb = mla_modules.rotary_emb
-        self.o_proj = mla_modules.o_proj
-        self.indexer = mla_modules.indexer
-        self.is_sparse = mla_modules.is_sparse
-
-        if self.indexer is not None:
-            assert hasattr(self.indexer, "topk_tokens")
-            self.topk_tokens = self.indexer.topk_tokens
-            self.topk_indices_buffer = mla_modules.topk_indices_buffer
-
-        # In the MLA backend, kv_cache includes both k_c and
-        # pe (i.e. decoupled position embeddings). In particular,
-        # the concat_and_cache_mla op requires
-        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
-        # i.e.
-        #     kv_lora_rank + qk_rope_head_dim == head_size
-        self.mla_attn = Attention(
-            num_heads=self.num_heads,
-            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
-            scale=scale,
-            num_kv_heads=1,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            use_mla=True,
-            use_sparse=mla_modules.is_sparse,
-            # MLA Args
-            q_lora_rank=self.q_lora_rank,
-            kv_lora_rank=self.kv_lora_rank,
-            qk_nope_head_dim=self.qk_nope_head_dim,
-            qk_rope_head_dim=self.qk_rope_head_dim,
-            qk_head_dim=self.qk_head_dim,
-            v_head_dim=self.v_head_dim,
-            kv_b_proj=self.kv_b_proj,
-            indexer=self.indexer,
-        )
-
-        self.prefix = prefix
-
-    def forward_native(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-        q_c = None
-        kv_lora = None
-
-        if self.q_lora_rank is not None:
-            assert self.fused_qkv_a_proj is not None, \
-                "fused_qkv_a_proj is required when q_lora_rank is not None"
-            assert self.q_a_layernorm is not None, \
-                "q_a_layernorm is required when q_lora_rank is not None"
-            assert self.q_b_proj is not None, \
-                "q_b_proj is required when q_lora_rank is not None"
-            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-            q_c, kv_lora = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
-            q_c = self.q_a_layernorm(q_c)
-            q = self.q_b_proj(q_c)[0]
-        else:
-            assert self.kv_a_proj_with_mqa is not None, \
-                "kv_a_proj_with_mqa is required when q_lora_rank is None"
-            assert self.q_proj is not None, \
-                "q_proj is required when q_lora_rank is None"
-            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
-            q = self.q_proj(hidden_states)[0]
-
-        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
-                                   dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c)
-
-        q = q.view(-1, self.num_heads, self.qk_head_dim)
-        # Add head dim of 1 to k_pe
-        k_pe = k_pe.unsqueeze(1)
-
-        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
-            positions, q[..., self.qk_nope_head_dim:], k_pe)
-
-        if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(hidden_states, q_c, positions,
-                                         self.rotary_emb)
-
-        hidden_states_shape_0 = 0
-        if isinstance(hidden_states, tuple):
-            x_q, x_scale = hidden_states
-            hidden_states_shape_0 = x_q.shape[0]
-        else:
-            hidden_states_shape_0 = hidden_states.shape[0]
-        attn_out = self.mla_attn(
-            q,
-            kv_c_normed,
-            k_pe,
-            output_shape=(hidden_states_shape_0,
-                          self.num_heads * self.v_head_dim))
-        return self.o_proj(attn_out)[0]
-
-    def forward_cuda(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
\ No newline at end of file
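
Reviewer note: with the local copy deleted, deepseek_v2.py only swaps its import and is expected to construct the layer exactly as before. A minimal sketch of the caller-side change, assuming the upstream vllm.model_executor.layers.mla module keeps the same MLAModules field layout as the deleted dataclass (the Identity placeholders below are illustrative only; real call sites pass the model's layernorms, parallel linears, and rotary embedding):

    import torch

    # Old import (deleted in this change):
    #   from vllm_kunlun.ops.attention.mla import MLAModules, MultiHeadLatentAttention
    # New import (upstream vLLM layer):
    from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention

    # Hypothetical wiring; field names copied from the deleted dataclass and
    # assumed unchanged upstream.
    mla_modules = MLAModules(
        kv_a_layernorm=torch.nn.Identity(),
        kv_b_proj=torch.nn.Identity(),
        rotary_emb=torch.nn.Identity(),
        o_proj=torch.nn.Identity(),
        fused_qkv_a_proj=torch.nn.Identity(),
        kv_a_proj_with_mqa=None,
        q_a_layernorm=torch.nn.Identity(),
        q_b_proj=torch.nn.Identity(),
        q_proj=None,
        indexer=None,
        is_sparse=False,
        topk_indices_buffer=None,
    )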