import functools
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
                    Type, TypeVar)

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionLayer,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionState,
                                              MLAAttentionImpl)
from vllm.attention.backends.mla.common import (MLACommonMetadata,
                                                triton_attention)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
                                               RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import current_platform

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    is_vllm_fa = False

try:
    # For ROCm use upstream flash attention
    from vllm.attention.backends.flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None

T = TypeVar("T", bound="MLACommonMetadata")


class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        # blocksparse_params: Optional[Dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        rotary_emb: RotaryEmbedding,
        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
        # attention backend perspective we rely on the layer to pass in the
        # correct matrix
        q_proj: ColumnParallelLinear,
        kv_b_proj: ColumnParallelLinear,
        o_proj: RowParallelLinear,
        positions: Optional[torch.Tensor] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        self.rotary_emb = rotary_emb
        self.use_yarn_rope = isinstance(rotary_emb,
                                        DeepseekScalingRotaryEmbedding)
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.positions = positions

        self.triton_fa_func = triton_attention

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn.
        # The former is used on ROCm and the latter has an additional
        # parameter to control FA2 vs FA3.
        self.flash_attn_varlen_func = flash_attn_varlen_func
        self.vllm_flash_attn_version = get_flash_attn_version()
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
                                  fa_version=self.vllm_flash_attn_version)

        # For MLA the v head dim is smaller than the qk head dim, so we pad
        # out v with 0s to match the qk head dim for attention backends that
        # do not support different head dims.
        # We don't need to pad V if we are on a Hopper system with FA3.
        self._pad_v = self.vllm_flash_attn_version is None or not (
            self.vllm_flash_attn_version == 3
            and current_platform.get_device_capability()[0] == 9)

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        # if attn_metadata.is_profile_run and \
        #         attn_metadata.context_chunk_workspace is not None:
        #     # During the profile run, try to simulate the worst-case output
        #     # size for `self.kv_b_proj(kv_c_normed)` in
        #     # `_compute_prefill_context`, since this can be large
        #     _ = torch.empty(
        #         (attn_metadata.context_chunk_workspace.shape[0],
        #          self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
        #         device=k_c_normed.device,
        #         dtype=k_c_normed.dtype,
        #     )

        has_decode = attn_metadata.decode_metadata is not None
        has_prefill = attn_metadata.prefill_metadata is not None

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)

        # assert hasattr(attn_metadata, "input_positions")
        if self.positions is not None:
            positions = self.positions
        elif hasattr(attn_metadata, "input_positions"):
            positions = attn_metadata.input_positions
        else:
            raise ValueError("no positions")

        num_prefill_tokens: int = attn_metadata.num_prefill_tokens

        decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
        decode_k_pe = k_pe[num_prefill_tokens:]
        decode_input_positions = positions[num_prefill_tokens:]

        prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
        prefill_k_pe = k_pe[:num_prefill_tokens]
        prefill_input_positions = positions[:num_prefill_tokens]
        prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

        if has_decode:
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                decode_input_positions, decode_q_pe, decode_k_pe)

        if has_prefill:
            prefill_q = self.q_proj(prefill_hs_or_q_c)[0] \
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                prefill_input_positions, prefill_q_pe, prefill_k_pe)

        # Write the latent and rope to the kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        # output = torch.empty(attn_metadata.num_prefill_tokens +
        #                      attn_metadata.num_decode_tokens,
        #                      self.o_proj.output_size,
        #                      device=hidden_states_or_q_c.device,
        #                      dtype=hidden_states_or_q_c.dtype)

        if has_prefill:
            return self._forward_prefill(prefill_q, prefill_k_c_normed,
                                         prefill_k_pe, kv_cache,
                                         attn_metadata)

        if has_decode:
            return self._forward_decode(decode_ql_nope, decode_q_pe, kv_cache,
                                        attn_metadata)

        assert False, \
            "MLA forward requires either prefill or decode metadata"
        return None
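
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the backend above): when `self._pad_v` is
# True, V is zero-padded from `v_head_dim` up to `qk_head_dim` before calling
# attention kernels that require matching Q/K/V head dims, as described in
# the comment in `__init__`. The shapes below are hypothetical and chosen
# only to show the padding step, not taken from any particular model.
# ---------------------------------------------------------------------------
def _demo_pad_v() -> None:
    import torch.nn.functional as F

    num_tokens, num_heads = 4, 8
    qk_head_dim, v_head_dim = 192, 128  # MLA: V head dim < QK head dim

    v = torch.randn(num_tokens, num_heads, v_head_dim)
    # Zero-pad the last dimension so V matches the QK head dim.
    v_padded = F.pad(v, (0, qk_head_dim - v_head_dim), value=0.0)
    assert v_padded.shape == (num_tokens, num_heads, qk_head_dim)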