# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
# MLA Common Components

This file implements common components for MLA implementations.

First we define:

Sq      as Q sequence length
Skv     as KV sequence length

MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach. We generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "large", near 1) and the
data-movement friendly approach for "decode" (i.e. the ratio Sq / Skv is
"small").

NOTE: what we deem small and large is currently determined by whether the
scheduler labels the request as prefill or decode, but this is something we
should probably tune.

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and
https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the data-movement friendly approach) the attention
  "simulates" multi-head attention, while the compute is similar to
  multi-query attention.

Below are examples of both paths, assuming batch size = 1.

## More Definitions:

C       Context length, `Skv - Sq`
H       hidden size
N       number of attention heads
Lq      latent dimension for Q              1536 in DSV3
Lkv     latent dimension for K/V            512 in DSV3
P       nope dimension, no rope.            128 in DSV3
R       rope dimension, goes through rope.  64 in DSV3
V       V head dim.                         128 in DSV3

## Vector/Matrix Definitions

h_t         hidden states (input to attention)  shape [Sq, H]
q_c         latent/compressed Q                 shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
q_pe        uncompressed Q (rope)               shape [Sq, N, R]
kv_c        latent/compressed KV                shape [Skv, Lkv]
k_pe        decoupled k position embeddings     shape [Skv, R]
new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
new_k_pe    new k_pe from current iter          shape [Sq, R]
cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters     shape [C, R]
W_DQ        project h_t to q_c                  shape [H, Lq]
W_UQ        project q_c to q_nope               shape [Lq, N * P]
W_QR        project q_c to q_pe                 shape [Lq, N * R]
W_DKV       project h_t to kv_c                 shape [H, Lkv]
W_UK        project kv_c to k_nope              shape [Lkv, N, P]
W_KR        project h_t to k_pe                 shape [H, R]
W_UV        project kv_c to v                   shape [Lkv, N, V]
W_O         project v to h_t                    shape [N * V, H]

## Compute Friendly Approach (i.e. "forward_mha"):

    q_c      = h_t @ W_DQ
    q_nope   = (q_c @ W_UQ).view(Sq, N, P)
    q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
    new_kv_c = h_t @ W_DKV
    new_k_pe = RoPE(h_t @ W_KR)
    kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
    k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
    k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
    v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

    // MHA with QK headdim = P + R
    //           V headdim = V
    //      spda_o shape [Sq, N, V]
    spda_o = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
        v
    )
    return spda_o @ W_O

NOTE: in the actual code,
    `kv_b_proj` is [W_UK; W_UV] concatenated per head
    `q_b_proj`  is [W_UQ; W_QR] concatenated per head
    `out_proj`  is W_O

## Data-Movement Friendly Approach (i.e. "forward_mqa"):
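Before the pseudocode, a quick sanity check (illustrative only, not part of the
implementation): this path can consume the latent cache directly because W_UK
can be absorbed into the query, giving the same attention scores as the
compute friendly path above. A toy, self-contained snippet with shrunk
dimensions (all names below are local to the example):

    import torch

    Sq, Skv, N, P, Lkv = 3, 7, 2, 8, 16
    q_nope = torch.randn(Sq, N, P)
    kv_c   = torch.randn(Skv, Lkv)
    W_UK   = torch.randn(Lkv, N, P)

    # Compute friendly: up-project the cache, then per-head (MHA) scores
    k_nope     = torch.einsum("sl,lnp->snp", kv_c, W_UK)        # [Skv, N, P]
    scores_mha = torch.einsum("qnp,snp->nqs", q_nope, k_nope)   # [N, Sq, Skv]

    # Data-movement friendly: absorb W_UK into the query instead, and
    # consume the latent cache as-is (MQA style)
    ql_nope    = torch.einsum("qnp,lnp->qnl", q_nope, W_UK)     # [Sq, N, Lkv]
    scores_mqa = torch.einsum("qnl,sl->nqs", ql_nope, kv_c)     # [N, Sq, Skv]

    assert torch.allclose(scores_mha, scores_mqa, atol=1e-4)

The same trick applies on the value side: W_UV can be applied to the attention
output after the fact (the `_v_up_proj` step in the code below).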
"forward_mqa"): Runtime q_c = h_t @ W_DQ q_nope = (q_c @ W_UQ).view(-1, N, P) ql_nope = einsum("snh,lnh->snl", q, W_UK) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) // MQA with QK headdim = Lkv + R // V headdim = Lkv // spda_o shape [Sq, N, Lkv] // NOTE: this is less compute-friendly since Lkv > P // but is more data-movement friendly since its MQA vs MHA spda_o = scaled_dot_product_attention( torch.cat([ql_nope, q_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1), kv_c ) o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill For chunked prefill we want to use the compute friendly algorithm. We are assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` To mitigate this, we chunk the computation of attention with respect to the current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. The chunked prefill approach is as follows: MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ q_nope = (q_c @ W_UQ).view(Sq, N, P) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) // MHA between queries and new KV // with QK headdim = P + R // V headdim = V // curr_o shape [Sq, N, V] // curr_lse shape [N, Sq], this is just order FA returns curr_o, curr_lse = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), new_v, casual=True, return_softmax_lse=True ) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)): chunk_start = chunk_idx * MCC chunk_end = min(chunk_start + MCC, C) Sc = chunk_end - chunk_start cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) chunk_o, chunk_lse = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([cache_k_nope_chunk, cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], dim=-1), cache_v_chunk, casual=False, return_softmax_lse=True ) curr_o, curr_lse = merge_attn_states( suffix_output=curr_o, suffix_lse=curr_lse, prefix_output=chunk_o, prefix_lse=chunk_lse, ) return curr_o @ W_O """ import functools from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast, Any if TYPE_CHECKING: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper import torch import torch.nn as nn from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from 
vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.attention.attention import ( _init_kv_cache_quant, get_attention_context, set_default_quant_scales, should_load_quant_weights, ) from vllm.model_executor.layers.attention.kv_transfer_utils import ( maybe_transfer_kv_layer, ) from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, get_and_maybe_dequant_weights, ) from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearBase, UnquantizedLinearMethod, ) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory from vllm.utils.math_utils import cdiv, round_down from vllm.utils.torch_utils import ( direct_register_custom_op, kv_cache_dtype_str_to_dtype, ) from vllm.v1.attention.backend import ( AttentionBackend, AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, MLAAttentionImpl, SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.utils import ( get_dcp_local_seq_lens, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills ) from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.ops.common import cp_lse_ag_out_rs from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import ( AttentionSpec, KVCacheSpec, MLAAttentionSpec, ) logger = init_logger(__name__) class MLAAttention(nn.Module, AttentionLayerBase): """Multi-Head Latent Attention layer. NOTE: Please read the comment at the top of the file before trying to understand this class This class takes query, and compressed key/value tensors as input. The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. 3. Return the output tensor. 
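    For DeepSeek-V3-style dimensions, the per-token tensors handed to forward()
    look roughly like the sketch below (illustrative only; the shapes are
    inferred from forward_impl, and the dimension names follow the definitions
    at the top of this file):

        import torch

        num_tokens, N, P, R, Lkv, V = 4, 128, 128, 64, 512, 128
        q           = torch.randn(num_tokens, N, P + R)  # per-head nope + rope query
        kv_c_normed = torch.randn(num_tokens, Lkv)       # latent KV after RMSNorm
        k_pe        = torch.randn(num_tokens, R)         # decoupled K position embeddings
        # forward(q, kv_c_normed, k_pe, positions) returns [num_tokens, N * V],
        # which the caller then projects with W_O / o_proj.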
""" def __init__( self, num_heads: int, scale: float, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, q_lora_rank: int | None, kv_lora_rank: int, kv_b_proj: ColumnParallelLinear, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", use_sparse: bool = False, indexer: object | None = None, rotary_emb: object | None = None, **extra_impl_args, ): super().__init__() self.num_heads = num_heads self.scale = scale self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.kv_b_proj = kv_b_proj self.head_size = kv_lora_rank + qk_rope_head_dim self.layer_name = prefix self.indexer = indexer self.rotary_emb = rotary_emb speculative_config = get_current_vllm_config().speculative_config self.use_spec_decode = ( speculative_config is not None and speculative_config.num_speculative_tokens is not None and speculative_config.num_speculative_tokens > 0 ) self.num_kv_heads = 1 self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False self.quant_config = quant_config # Initialize KV cache quantization attributes self.kv_cache_dtype = kv_cache_dtype self.calculate_kv_scales = calculate_kv_scales _init_kv_cache_quant(self, quant_config, prefix) dtype = torch.get_default_dtype() self.attn_backend = get_attn_backend( self.head_size, dtype, kv_cache_dtype, block_size, use_mla=True, use_sparse=use_sparse, num_heads=self.num_heads, ) if ( cache_config is not None and cache_config.enable_prefix_caching and vllm_is_batch_invariant() and ( self.attn_backend.get_name() == "TRITON_MLA" or self.attn_backend.get_name() == "FLASHINFER" ) ): logger.warning_once( "Disabling prefix caching for TRITON_MLA / FLASHINFER " "with batch invariance, as it is not yet supported.", scope="local", ) cache_config.enable_prefix_caching = False impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) self.impl = impl_cls( num_heads=self.num_heads, head_size=self.head_size, scale=self.scale, num_kv_heads=1, alibi_slopes=None, sliding_window=None, kv_cache_dtype=self.kv_cache_dtype, logits_soft_cap=None, attn_type=AttentionType.DECODER, kv_sharing_target_layer_name=None, # MLA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, qk_nope_head_dim=self.qk_nope_head_dim, qk_rope_head_dim=self.qk_rope_head_dim, qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, v_head_dim=self.v_head_dim, kv_b_proj=kv_b_proj, indexer=indexer, **extra_impl_args, ) self.q_pad_num_heads = getattr(self.impl, "q_pad_num_heads", None) self.use_direct_call = True compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self self.kv_cache = [ torch.tensor([]) for _ in range( get_current_vllm_config().parallel_config.pipeline_parallel_size ) ] if envs.VLLM_USE_INT8_MLA: self.kv_cache_scale = [ torch.tensor([]) for _ in range(get_current_vllm_config( ).parallel_config.pipeline_parallel_size) ] self.is_int8_mla = envs.VLLM_USE_INT8_MLA self.use_sparse = use_sparse # Initialize q/k/v range constants. 
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported self.is_aiter_triton_fp4_bmm_enabled = ( rocm_aiter_ops.is_fp4bmm_enabled() and self.kv_b_proj.weight.dtype == torch.bfloat16 ) # Attributes for forward_impl method self.chunked_prefill_workspace_size = ( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( get_current_vllm_config() ) ) self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( static=True, group_shape=GroupShape.PER_TENSOR, compile_native=True, ) def forward( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, position: torch.Tensor, output_shape: torch.Size | None = None, ) -> torch.Tensor: optional_args = {} if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) if self.use_direct_call: def direct_forward(layer_name: str, output_shape: torch.Size | None): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] if self.is_int8_mla: optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine] if self.attn_backend.accept_output_buffer: output_shape = (output_shape if output_shape is not None else q.shape) output = torch.zeros(output_shape, dtype=q.dtype, device=q.device) output = self.forward_impl( q, kv_c_normed, k_pe, self_kv_cache, attn_metadata, output=output, positions=position, **optional_args ) return output else: return self.forward_impl( q, kv_c_normed, k_pe, self_kv_cache, attn_metadata, position=position ) return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output_shape) else: kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_c_normed, k_pe, self.layer_name, self.kv_cache_dtype, self._k_scale, ) if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) torch.ops.vllm.unified_mla_attention_with_output( q, kv_c_normed, k_pe, output, self.layer_name, kv_cache_dummy_dep=kv_cache_dummy_dep, ) return output else: return torch.ops.vllm.unified_mla_attention( q, kv_c_normed, k_pe, self.layer_name, kv_cache_dummy_dep=kv_cache_dummy_dep, ) def forward_impl( self, q: torch.Tensor, k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: "MLACommonMetadata", output: torch.Tensor | None = None, positions: torch.Tensor | None = None, kv_cache_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." assert positions is not None, "positions tensor must be provided." 
if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported for MLA" ) if attn_metadata is None: # During the profile run try to simulate to worse case output size # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` # since this can be large _ = torch.empty( ( self.chunked_prefill_workspace_size, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ), device=k_c_normed.device, dtype=k_c_normed.dtype, ) # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. output = torch.empty(output.shape[0], self.v_head_dim * self.num_heads, device=q.device, dtype=q.dtype) return output if self.impl.dcp_world_size == -1: self.impl.dcp_world_size = get_dcp_group().world_size fp8_attention = self.kv_cache_dtype.startswith("fp8") is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl) if not is_sparse_impl: assert ( attn_metadata.num_decodes is not None and attn_metadata.num_prefills is not None and attn_metadata.num_decode_tokens is not None ) has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens decode_q = q[:num_decode_tokens] k_pe = k_pe.unsqueeze(1) prefill_q = q[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] prefill_k = torch.empty_like(prefill_q) # write the latent and rope to kv cache write_kv_cache = (None, None) if kv_cache.numel() > 0: if has_decode: decode_q_pe, decode_k_pe = self.rotary_emb(positions[:num_decode_tokens], decode_q[..., self.qk_nope_head_dim:], k_pe[:num_decode_tokens],) if envs.VLLM_USE_INT8_MLA: k_c_normed_int8, k_c_normed_scale,_ = ops.scaled_int8_quant(k_c_normed[:num_decode_tokens]) decode_k_pe_int8, decode_k_pe_scale,_ = ops.scaled_int8_quant(decode_k_pe.contiguous()) ops.concat_and_cache_mla_int8( kv_c_int8 = k_c_normed_int8, kv_c_scale = k_c_normed_scale[...,0], k_pe_int8 = decode_k_pe_int8, k_pe_scale = decode_k_pe_scale[...,0].view(-1,decode_k_pe_int8.shape[-2]), kv_cache = kv_cache, kv_cache_scale = kv_cache_scale, slot_mapping = attn_metadata.slot_mapping.flatten()[:num_decode_tokens], kv_cache_dtype=self.kv_cache_dtype, scale=self._k_scale, ) else: if self.impl.dcp_world_size > 1 or self.use_spec_decode: ops.concat_and_cache_mla( k_c_normed[:num_decode_tokens], decode_k_pe, kv_cache, attn_metadata.slot_mapping.flatten()[:num_decode_tokens], kv_cache_dtype=self.kv_cache_dtype, scale=self._k_scale, ) else: write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe) if has_prefill: ixf_ops.mla_rope(positions[num_decode_tokens:], prefill_q[..., self.qk_nope_head_dim:], prefill_k_pe.squeeze(1), prefill_k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) if envs.VLLM_USE_INT8_MLA: prefill_k_c_normed_int8, prefill_k_c_normed_scale,_ = ops.scaled_int8_quant(prefill_k_c_normed) prefill_k_pe_int8, prefill_k_pe_scale,_ = ops.scaled_int8_quant(prefill_k[...,self.qk_nope_head_dim:].contiguous()) ops.concat_and_cache_mla_int8( prefill_k_c_normed_int8, prefill_k_c_normed_scale[...,0], prefill_k_pe_int8, prefill_k_pe_scale[...,0].view(-1,prefill_k_pe_int8.shape[-2]), kv_cache, kv_cache_scale, attn_metadata.slot_mapping.flatten()[num_decode_tokens:], self.kv_cache_dtype, self._k_scale, ) else: ops.concat_and_cache_mla( prefill_k_c_normed, prefill_k[...,self.qk_nope_head_dim:], kv_cache, attn_metadata.slot_mapping.flatten()[num_decode_tokens:], 
kv_cache_dtype=self.kv_cache_dtype, scale=self._k_scale, ) output = torch.empty(output.shape[0], self.num_heads, self.v_head_dim, device=q.device, dtype=q.dtype) if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla": kv_cache = kv_cache.view(current_platform.fp8_dtype()) # Sparse MLA impls only support forward_mqa (decode-style attention) if has_prefill: output[num_decode_tokens:] = self.impl.forward_mha( prefill_q, prefill_k_c_normed, prefill_k, kv_cache, attn_metadata, kv_c_and_k_pe_cache_scale=kv_cache_scale ) if has_decode: # For sparse impl, we always use forward_mqa for all tokens # For non-sparse impl, we only use forward_mqa for decode tokens assert attn_metadata.decode is not None mqa_q = decode_q attn_out, lse = self.impl.forward_mqa( mqa_q[..., :self.qk_nope_head_dim], decode_q_pe, kv_cache, attn_metadata, *write_kv_cache, kv_c_and_k_pe_cache_scale=kv_cache_scale) if self.impl.dcp_world_size > 1: assert lse is not None attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) output[:num_decode_tokens] = self.impl._v_up_proj(attn_out) else: assert lse is None output[:num_decode_tokens] = attn_out return output.view(output.shape[0], self.v_head_dim * self.num_heads) else: num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs k_pe = k_pe.unsqueeze(1) q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] positions = positions[:num_actual_toks, ...] num_mqa_tokens = q.size(0) if num_mqa_tokens > 0: q = q[:num_mqa_tokens] q_nope, q_pe = q.split( [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) q_pe, k_pe = self.rotary_emb(positions[:num_actual_toks], q_pe, k_pe) q_nope = self.impl._k_up_proj(q_nope) q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank) q = torch.cat([q_nope, q_pe], dim=-1) if kv_cache.numel() > 0: ops.concat_and_cache_mla( k_c_normed, k_pe, kv_cache, attn_metadata.slot_mapping.flatten(), kv_cache_dtype=self.kv_cache_dtype, scale=self._k_scale, ) attn_out= self.impl.forward_mqa(q, kv_cache, attn_metadata, self) output = torch.empty(output.shape[0], self.num_heads, self.v_head_dim, device=q.device, dtype=q.dtype) output[:num_actual_toks] = self.impl._v_up_proj(attn_out) return output.view(output.shape[0], self.v_head_dim * self.num_heads) def process_weights_after_loading(self, act_dtype: torch.dtype): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) # If we should not load quant weights, we initialize the scales to 1.0 # as the default value. See [Note: Register q/k/v/prob scales in state dict] # for more details. quant_method = ( self.quant_config.get_quant_method(self, prefix=self.layer_name) if self.quant_config else None ) if not should_load_quant_weights(quant_method): set_default_quant_scales(self, register_buffer=False) def calc_kv_scales( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor ) -> None: """Optional scale calculation for MLA inputs. Mirrors Attention.calc_kv_scales. 
Not all MLA backends require this """ # Use safe defaults if ranges are not present q_range = getattr(self, "q_range", torch.tensor(1.0)) k_range = getattr(self, "k_range", torch.tensor(1.0)) v_range = getattr(self, "v_range", torch.tensor(1.0)) self._q_scale.copy_(torch.abs(q).max() / q_range) # kv_c_normed is the compressed KV representation; use it for k/v kv_abs_max = torch.abs(kv_c_normed).max() self._k_scale.copy_(kv_abs_max / k_range) self._v_scale.copy_(kv_abs_max / v_range) self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() self.calculate_kv_scales = False def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: kv_cache_dtype = kv_cache_dtype_str_to_dtype( self.kv_cache_dtype, vllm_config.model_config ) return MLAAttentionSpec( block_size=vllm_config.cache_config.block_size, num_kv_heads=1, head_size=self.head_size, dtype=kv_cache_dtype, cache_dtype_str=vllm_config.cache_config.cache_dtype, ) @maybe_transfer_kv_layer def unified_mla_attention( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: # kv_cache_dummy_dep is not used but accepting it creates a data dependency # that ensures torch.compile preserves ordering between KV cache update and # attention forward. del kv_cache_dummy_dep attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) return output def unified_mla_attention_fake( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q).contiguous() direct_register_custom_op( op_name="unified_mla_attention", op_func=unified_mla_attention, mutates_args=[], fake_impl=unified_mla_attention_fake, dispatch_key=current_platform.dispatch_key, ) def unified_mla_kv_cache_update( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, kv_cache_dtype: str, k_scale: torch.Tensor, ) -> torch.Tensor: """ Returns a dummy that is passed to unified_attention to signal a side effect and the data dependency between them to ensure torch.compile preserves ordering. """ forward_context = get_forward_context() attn_layer = forward_context.no_compile_layers[layer_name] kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] slot_mapping = forward_context.slot_mapping assert isinstance(slot_mapping, dict), ( f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. 
" ) layer_slot_mapping = slot_mapping.get(layer_name) if layer_slot_mapping is not None: attn_layer.impl.do_kv_cache_update( kv_c_normed, k_pe, kv_cache, layer_slot_mapping, kv_cache_dtype, k_scale, ) return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) def unified_mla_kv_cache_update_fake( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, kv_cache_dtype: str, k_scale: torch.Tensor, ) -> torch.Tensor: return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) direct_register_custom_op( op_name="unified_mla_kv_cache_update", op_func=unified_mla_kv_cache_update, fake_impl=unified_mla_kv_cache_update_fake, ) @maybe_transfer_kv_layer def unified_mla_attention_with_output( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: # kv_cache_dummy_dep is not used but accepting it creates a data dependency # that ensures torch.compile preserves ordering between KV cache update and # attention forward. del kv_cache_dummy_dep attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) layer.forward_impl( q, kv_c_normed, k_pe, kv_cache, attn_metadata, output=output, output_scale=output_scale, output_block_scale=output_block_scale, ) def unified_mla_attention_with_output_fake( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: return direct_register_custom_op( op_name="unified_mla_attention_with_output", op_func=unified_mla_attention_with_output, mutates_args=["output", "output_block_scale"], fake_impl=unified_mla_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) class QueryLenSupport(Enum): """Defines the level of query length support for an attention backend's decode pipeline. - SINGLE_ONLY: Decode pipeline only supports single-token queries (query_len=1) - UNIFORM: Decode pipeline supports uniform multi-token queries (all requests must have same query_len > 1) - VARLEN: Decode pipeline supports variable-length queries (mixed query lengths in same batch) """ SINGLE_ONLY = "single_only" UNIFORM = "uniform" VARLEN = "varlen" try: from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func,merge_attn_states is_vllm_fa = True except ImportError: is_vllm_fa = False flash_attn_varlen_func = None # type: ignore[assignment] # On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead. # On CUDA, vllm_flash_attn should always be available (built with vLLM), # so we don't attempt the fallback there. if current_platform.is_rocm(): try: from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] except ImportError: logger.debug( "flash_attn not available on ROCm; " "MLA models using TRITON_MLA will require flash_attn. " "AITER_MLA backends use aiter kernels instead." 
) def dynamic_per_batched_tensor_quant( x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn ): DTYPE_MAX = torch.finfo(dtype).max min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) scale = DTYPE_MAX / amax x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() logger = init_logger(__name__) @CustomOp.register( "mla_decode_concat_quant_fp8", dynamic_arg_dims={"decode_ql_nope": 0, "decode_q_pe": 0}, ) class _DecodeConcatQuantFP8(QuantFP8): """ QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before quantization. When disabled, forward_native is compiled via torch.compile, fusing cat/reshape/quant/view together. """ def _make_forward(quant_fn): # noqa: N805 """Factory to create forward methods that concat before quantization.""" def forward( self, decode_ql_nope: torch.Tensor, decode_q_pe: torch.Tensor, scale: torch.Tensor, scale_ub: torch.Tensor | None = None, ) -> torch.Tensor: decode_q0 = torch.cat((decode_ql_nope, decode_q_pe), dim=-1) decode_q_flat = decode_q0.reshape(decode_q0.shape[0], -1) decode_q, _ = quant_fn(self, decode_q_flat, scale, scale_ub) return decode_q.view(decode_q0.shape) return forward forward_native = _make_forward(QuantFP8.forward_native) # type: ignore[arg-type] forward_cuda = _make_forward(QuantFP8.forward_cuda) # type: ignore[arg-type] forward_hip = _make_forward(QuantFP8.forward_hip) # type: ignore[arg-type] CUDNN_WORKSPACE_SIZE = 12800 from vllm import envs from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize import ixformer.inference.functions as ixf_ops import numpy as np class MLACommonBackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod def get_name() -> str: return "TRITON_MLA" @staticmethod def get_builder_cls() -> type["MLACommonMetadataBuilder"]: return MLACommonMetadataBuilder @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) @staticmethod def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. 
# (num_blocks, num_layers, block_size, head_size) return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2) @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @classmethod def is_mla(cls) -> bool: return True @dataclass class MLACommonPrefillMetadata: """Prefill Specific Metadata""" @dataclass class ChunkedContextMetadata: # New for MLA (compared to FlashAttention) # For handling chunked prefill cu_seq_lens: torch.Tensor starts: torch.Tensor seq_tot: list[int] max_seq_lens: list[int] seq_lens: torch.Tensor workspace: torch.Tensor token_to_seq: torch.Tensor chunk_total_token: list[int] # for mla DCP padded_local_chunk_seq_lens: list[list[int]] | None = None local_context_lens_allranks: list[list[int]] | None = None padded_local_cu_seq_lens: torch.Tensor | None = None cu_seq_lens_lst: list[list[int]] | None = None chunk_size: int | None = None block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int chunked_context: ChunkedContextMetadata | None = None query_seq_lens: torch.Tensor | None = None workspace_buffer: torch.Tensor | None = None q_data_type: torch.dtype | None = None output_dtype: torch.dtype | None = None @dataclass class FlashInferPrefillMetadata(MLACommonPrefillMetadata): prefill_main: "BatchPrefillWithRaggedKVCacheWrapper | None" = None prefill_chunks: "list[BatchPrefillWithRaggedKVCacheWrapper]" = field( default_factory=list ) @dataclass class CudnnPrefillMetadata(MLACommonPrefillMetadata): class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): seq_lens: torch.Tensor cudnn_workspace: torch.Tensor | None = None @dataclass class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor dcp_tot_seq_lens: torch.Tensor | None max_decode_seq_len: int use_cuda_graph: bool D = TypeVar("D", bound=MLACommonDecodeMetadata) @dataclass class MLACommonMetadata(AttentionMetadata, Generic[D]): """Metadata for MLACommon. NOTE: Please read the comment at the top of the file before trying to understand this class """ # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seq_len ---------------------| # |-- query_len ---| num_reqs: int max_query_len: int max_seq_len: int num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor slot_mapping: torch.Tensor # New for MLA (compared to FlashAttention) # For handling prefill decode split num_decodes: int num_decode_tokens: int num_prefills: int # The dimension of the attention heads head_dim: int | None = None decode: D | None = None prefill: ( MLACommonPrefillMetadata | FlashInferPrefillMetadata | CudnnPrefillMetadata | None ) = None def __post_init__(self): if self.head_dim is not None and not MLACommonBackend.supports_head_size( self.head_dim ): raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") M = TypeVar("M", bound=MLACommonMetadata) A = TypeVar("A", bound=AttentionMetadata) def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool: # Check if model has DeepSeek R1 compatible MLA dimensions: # qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128 # which results in query/key head dim = 192. 
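    # Illustrative only (not executed in practice): the DeepSeek-V3/R1
    # dimensions this check corresponds to, matching the module docstring.
    _dsv3_nope, _dsv3_rope, _dsv3_v, _dsv3_kv_lora = 128, 64, 128, 512
    assert _dsv3_nope + _dsv3_rope == 192        # query/key head dim
    assert _dsv3_kv_lora + _dsv3_rope == 576     # MLA cache head size; this is the
                                                 # single entry in get_supported_head_sizes()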
if vllm_config.model_config is None: return False hf_text_config = vllm_config.model_config.hf_text_config qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 1) v_head_dim = getattr(hf_text_config, "v_head_dim", 1) return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128 @functools.cache def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() if not ( not vllm_config.attention_config.disable_flashinfer_prefill and has_flashinfer() and not vllm_config.attention_config.use_cudnn_prefill and current_platform.is_device_capability_family(100) ): return False return is_deepseek_r1_mla_compatible(vllm_config) @functools.cache def use_cudnn_prefill() -> bool: from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() return ( has_flashinfer() and vllm_config.attention_config.use_cudnn_prefill and current_platform.is_device_capability_family(100) and has_nvidia_artifactory() ) @functools.cache def use_trtllm_ragged_deepseek_prefill() -> bool: """Check if TRT-LLM ragged DeepSeek prefill should be used.""" from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() if not ( has_flashinfer() and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill and current_platform.is_device_capability_family(100) ): return False return is_deepseek_r1_mla_compatible(vllm_config) @dataclass class MLADims: q_lora_rank: int | None kv_lora_rank: int qk_nope_head_dim: int qk_rope_head_dim: int v_head_dim: int def get_mla_dims(model_config: ModelConfig) -> MLADims: hf_text_config = model_config.hf_text_config return MLADims( q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), kv_lora_rank=hf_text_config.kv_lora_rank, qk_nope_head_dim=hf_text_config.qk_nope_head_dim, qk_rope_head_dim=hf_text_config.qk_rope_head_dim, v_head_dim=hf_text_config.v_head_dim, ) @functools.cache def backend_supports_prefill_query_quantization() -> bool: """Check if the selected MLA backend supports prefill query quantization. Currently supported backends: - FlashInfer prefill - TRT-LLM ragged DeepSeek prefill Not supported: - cuDNN Prefill - FlashAttention - Non-GB200 devices (FP8 prefill requires device capability 100) """ # FP8 prefill query quantization requires GB200 (device capability 100) # for the necessary FP8 kernels at the moment. if not current_platform.is_device_capability_family(100): return False return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill() class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ # Defines the level of query length support for this backend. # - SINGLE_ONLY: Only single-token queries (no spec decode support) # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths) # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when # speculative decoding is enabled. query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM # The threshold for reordering the batch into decode and prefill requests. # If > 1, the batch will be reordered such that requests with # query length <= threshold are classified as decode requests. 
# Use `query_len_support` (above) to set this automatically # when speculative decoding is enabled. reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH @staticmethod def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config model_config = vllm_config.model_config chunked_prefill_workspace_size = min( # Try for 8 full length request or at least 4 pages per-request max( 8 * model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size, ), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, # which would result in the workspace being: # 2*(576)*(64*1024) = 144mb # (assuming 576 MLA head dim, and fp16) # which would result in up-projected context being # 2*(192*128)*(64*1024) = 3gb # (assuming 192 QK head dim, 128 heads, and fp16) 64 * 1024, ) # Enforce that we enough for at least 1 page per request chunked_prefill_workspace_size = max( chunked_prefill_workspace_size, scheduler_config.max_num_seqs * cache_config.block_size, ) return chunked_prefill_workspace_size @staticmethod def determine_prefill_query_data_type( vllm_config: VllmConfig, model_dtype: torch.dtype, ) -> torch.dtype: """ Determine the query data type for prefill queries. Return FP8 dtype if cache is FP8 and prefill query quantization is enabled, else model dtype. """ use_fp8 = ( vllm_config.cache_config.cache_dtype.startswith("fp8") and vllm_config.attention_config.use_prefill_query_quantization and backend_supports_prefill_query_quantization() ) if use_fp8: fp8_dtype = current_platform.fp8_dtype() logger.info_once( "FP8 prefill attention enabled: query data type is FP8", scope="local" ) return fp8_dtype elif vllm_config.attention_config.use_prefill_query_quantization: logger.info_once( "Unable to perform FP8 prefill attention when" " use_prefill_query_quantization is enabled. 
Please" " ensure that --kv-cache-dtype is set to fp8 and your prefill" " backend is compatible with FP8 attention.", scope="local", ) return model_dtype return model_dtype def __init__( self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device, metadata_cls: type[M] | None = None, supports_dcp_with_varlen: bool = False, ): self.metadata_cls = ( metadata_cls if metadata_cls is not None else MLACommonMetadata ) self.kv_cache_spec = kv_cache_spec scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.compilation_config = vllm_config.compilation_config self.decode_use_graph = ( vllm_config.compilation_config.cudagraph_mode.decode_use_graph() ) self.use_full_cuda_graph = ( vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() ) self.vllm_config = vllm_config self.device = device self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() self.kv_cache_spec = kv_cache_spec self.q_data_type = self.determine_prefill_query_data_type( vllm_config, self.model_config.dtype ) try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size # Don't try to access the runner on AMD if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size self.chunked_prefill_workspace_size = ( self.determine_chunked_prefill_workspace_size(vllm_config) ) if self.dcp_world_size > 1: # Note(hc): The local kvcache is incomplete when DCP is triggered, # an additional kvcache allgather across the DCP group is therefore # required, so the workspace has to be enlarged by 1/DCP relative # to the original TP allocation. 
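        # Illustrative only: with made-up values max_model_len=4096,
        # max_num_seqs=256, block_size=16 (see
        # determine_chunked_prefill_workspace_size above) and dcp_world_size=4,
        # the sizing works out to:
        _example_ws = min(max(8 * 4096, 4 * 256 * 16), 64 * 1024)  # min(32768, 65536) = 32768
        _example_ws = max(_example_ws, 256 * 16)                   # >= 1 page per request -> 32768
        _example_rows = _example_ws + _example_ws // 4             # DCP allgather headroom -> 40960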
assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 self.chunked_prefill_workspace = torch.empty( ( self.chunked_prefill_workspace_size + self.chunked_prefill_workspace_size // self.dcp_world_size, self.model_config.get_head_size(), ), dtype=self.model_config.dtype, device=device, ) else: self.chunked_prefill_workspace = torch.empty( ( self.chunked_prefill_workspace_size, self.model_config.get_head_size(), ), dtype=self.q_data_type, device=device, ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill() self.prefill_metadata_cls = ( FlashInferPrefillMetadata if self._use_fi_prefill else CudnnPrefillMetadata if self._use_cudnn_prefill else MLACommonPrefillMetadata ) if self._use_fi_prefill: self._workspace_buffer = torch.empty( envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device, ) self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) # type: ignore[type-abstract] ) if self._use_trtllm_ragged_prefill: self._workspace_buffer = torch.empty( envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device, ) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, dtype=torch.int8, device=device, ) supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY self._init_reorder_batch_threshold( self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen ) # Validate consistency between query_len_support and reorder_batch_threshold if self.query_len_support == QueryLenSupport.SINGLE_ONLY: assert self.reorder_batch_threshold == 1, ( f"reorder_batch_threshold must be 1 when query_len_support is " f"SINGLE_ONLY, got {self.reorder_batch_threshold}" ) # Spec-decode expands block_table / seq_lens via index_select; CUDA graph # replay expects stable tensor storage. Copy expanded rows into these # buffers when decode or full CUDA graph is enabled. 
self._spec_decode_expand_block_table_buf: torch.Tensor | None = None self._spec_decode_expand_seq_lens_buf: torch.Tensor | None = None def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc has_context = False if prefill.chunked_context is not None: chunked_context = prefill.chunked_context has_context = True if self._fi_prefill_main is None: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( self._workspace_buffer, "NHD", backend="cutlass" ) if has_context: num_chunks = chunked_context.cu_seq_lens.shape[0] # Allocate more prefill chunk wrappers if needed if len(self._fi_prefill_chunks) < num_chunks: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper for _ in range(len(self._fi_prefill_chunks), num_chunks): self._fi_prefill_chunks.append( BatchPrefillWithRaggedKVCacheWrapper( self._workspace_buffer, "NHD", backend="cutlass" ) ) assert num_chunks <= len(self._fi_prefill_chunks) # In MLA, the non-latent num_qo_heads == num_kv_heads num_qo_heads = self.num_heads num_kv_heads = num_qo_heads # Sanity: Verify that num_kv_heads == 1 since it is latent space assert self.kv_cache_spec.num_kv_heads == 1 # Get non-latent head_dim_qk and head_dim_vo head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim head_dim_vo = self.mla_dims.v_head_dim # For main run, qo_indptr == kv_indptr kv_indptr = qo_indptr.clone() # Prepare main prefill self._fi_prefill_main.plan( qo_indptr=qo_indptr, kv_indptr=kv_indptr, num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim_qk=head_dim_qk, head_dim_vo=head_dim_vo, causal=True, # This is main run sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.q_data_type, o_data_type=prefill.output_dtype, ) # Prepare context prefills if has_context: for i in range(num_chunks): kv_indptr_chunk = chunked_context.cu_seq_lens[i] self._fi_prefill_chunks[i].plan( qo_indptr=qo_indptr, kv_indptr=kv_indptr_chunk, num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim_qk=head_dim_qk, head_dim_vo=head_dim_vo, causal=False, # This is context run sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.q_data_type, o_data_type=prefill.output_dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks def _build_decode( self, block_table_tensor: torch.Tensor, seq_lens_device: torch.Tensor, max_seq_len: int, query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, dcp_tot_seq_lens_device: torch.Tensor | None, max_decode_seq_len: int, use_cuda_graph: bool ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, dcp_tot_seq_lens=dcp_tot_seq_lens_device, max_decode_seq_len=max_decode_seq_len, use_cuda_graph=use_cuda_graph, ) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( "MLA only supports decode-only full CUDAGraph capture. 
" "Make sure all cudagraph capture sizes <= max_num_seq." ) assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> M: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. device = self.device block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens seq_lens_np = common_attn_metadata.seq_lens_np num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold, require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), ) ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens prefill_metadata = None if num_prefills > 0: num_computed_tokens_cpu = ( common_attn_metadata.compute_num_computed_tokens().cpu() ) reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = ( query_start_loc[reqs_start:] - query_start_loc[reqs_start] ) chunked_context_metadata = None if max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code # currently we allocate an equal amount of workspace for each # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths max_context_chunk = ( self.chunked_prefill_workspace_size // num_prefills_with_context_cpu ) if self.aot_schedule: # align max_context_chunk to page_size by rounding down, # currently the `gather_and_maybe_dequant_cache` kernel # cannot handle `context_chunk_starts` that are not aligned # to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) # if `max_context_chunk = 256`, `num_chunks = 3`, and # `num_prefills_with_context = 4`, create a tensor that looks # like # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] # Note(simon): this is done in CPU because of downstream's # of `to_list`. 
chunk_starts = ( torch.arange(num_chunks, dtype=torch.int32) .unsqueeze(1) .expand(-1, num_prefills) * max_context_chunk ) chunk_ends = torch.min( context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk ) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) cu_seq_lens_cpu = torch.zeros( num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True ) torch.cumsum( chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 ) chunk_total_token = cu_seq_lens_cpu[:, -1] max_token_num_over_chunk = chunk_total_token.max().item() token_to_seq_tensor_cpu = torch.zeros( [num_chunks, max_token_num_over_chunk], dtype=torch.int32 ) range_idx = torch.arange(num_prefills, dtype=torch.int32) for i in range(num_chunks): chunk_token_to_seq_tensor = torch.repeat_interleave( range_idx, chunk_seq_lens[i] ) chunk_len = chunk_token_to_seq_tensor.shape[0] token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor if self.dcp_world_size > 1: local_context_lens_allranks = get_dcp_local_seq_lens( context_lens_cpu, self.dcp_world_size, None, self.dcp_local_block_size, ) # Note(qcs): The max local context lengths # padded to `dcp_local_block_size`. padded_local_context_lens_cpu: torch.Tensor = ( cdiv( context_lens_cpu, self.dcp_virtual_block_size, ) * self.dcp_local_block_size ) # Note(hc): The above max_context_chunk already enforces # block_size alignment, DCP just need the block_size can # be divisible by dcp_world_size, because DCP use # cp_gather_cache which not require `cp_chunk_starts` # aligned to page_size. assert max_context_chunk % self.dcp_world_size == 0 padded_local_max_context_chunk_across_ranks = ( cdiv( max_context_chunk, self.dcp_virtual_block_size, ) * self.dcp_local_block_size ) local_chunk_starts = ( torch.arange(num_chunks, dtype=torch.int32) .unsqueeze(1) .expand(-1, num_prefills) * padded_local_max_context_chunk_across_ranks ) local_chunk_ends = torch.min( padded_local_context_lens_cpu.unsqueeze(0), local_chunk_starts + padded_local_max_context_chunk_across_ranks, ) padded_local_chunk_seq_lens = ( local_chunk_ends - local_chunk_starts ).clamp(min=0) padded_local_cu_chunk_seq_lens_cpu = torch.zeros( num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True ) torch.cumsum( padded_local_chunk_seq_lens, dim=1, out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) chunked_context_metadata_cls = ( CudnnPrefillMetadata.ChunkedContextMetadata if self._use_cudnn_prefill else MLACommonPrefillMetadata.ChunkedContextMetadata ) if self.dcp_world_size > 1: chunked_context_metadata = chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=local_chunk_starts.to(device, non_blocking=True), seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), seq_lens=chunk_seq_lens, token_to_seq=token_to_seq_tensor_cpu.to( device, non_blocking=True ), chunk_total_token=chunk_total_token.tolist(), workspace=self.chunked_prefill_workspace, padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), local_context_lens_allranks=local_context_lens_allranks.tolist(), padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( device, non_blocking=True ), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), chunk_size=padded_local_max_context_chunk_across_ranks, ) else: chunked_context_metadata = chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), 
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), seq_lens=chunk_seq_lens, token_to_seq=token_to_seq_tensor_cpu.to( device, non_blocking=True ), chunk_total_token=chunk_total_token, workspace=self.chunked_prefill_workspace, ) if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens assert ( max(chunked_context_metadata.max_seq_lens) <= self.chunked_prefill_workspace_size ) prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, output_dtype=self.model_config.dtype, q_data_type=self.q_data_type, ) if self._use_cudnn_prefill: assert isinstance(prefill_metadata, CudnnPrefillMetadata) prefill_metadata.query_seq_lens = ( prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] ) prefill_metadata.cudnn_workspace = self.cudnn_workspace if self._use_trtllm_ragged_prefill: prefill_metadata.query_seq_lens = ( prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] ) prefill_metadata.workspace_buffer = self._workspace_buffer decode_metadata = None if num_decodes > 0: dcp_tot_seq_lens_device = None if self.dcp_world_size > 1: dcp_tot_seq_lens_device = seq_lens[:num_decodes] seq_lens = dcp_local_seq_lens # After DCP distribution, the maximum number of tokens for any rank is # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, # and I is cp_kv_cache_interleave_size. # This eliminates GPU->CPU sync while minimizing workspace # over-allocation. num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size max_seq_len = ( (max_seq_len + num_partitions - 1) // num_partitions ) * self.cp_kv_cache_interleave_size decode_block_table = block_table_tensor[:num_decodes, ...] decode_seq_lens_np = seq_lens_np[:num_decodes] decode_seq_lens_device = seq_lens[:num_decodes] decode_query_start_loc_cpu = query_start_loc_cpu[: num_decodes + 1] decode_query_start_loc = query_start_loc[: num_decodes + 1] max_decode_seq_len = np.max(decode_seq_lens_np).item() speculative_config = self.vllm_config.speculative_config use_spec_decode = ( speculative_config is not None and speculative_config.num_speculative_tokens is not None and speculative_config.num_speculative_tokens > 0 ) # For non-DCP speculative decode, expand decode metadata from # request-level rows to token-level rows so block_table/seq_lens # align with num_decode_tokens. 
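            # Illustrative only: with two decode requests whose query lens are
            # [2, 3] (e.g. one verified token plus draft tokens), the expansion
            # below produces one block-table row per token rather than per
            # request.
            _decode_lens = torch.tensor([2, 3])
            _req_ids = torch.repeat_interleave(torch.arange(2), _decode_lens)  # [0, 0, 1, 1, 1]
            _block_table = torch.tensor([[10, 11], [20, 21]])
            _expanded = _block_table.index_select(0, _req_ids)
            # -> [[10, 11], [10, 11], [20, 21], [20, 21], [20, 21]]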
if ( use_spec_decode and self.dcp_world_size == 1 and num_decode_tokens > num_decodes ): decode_lens_cpu = ( decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1] ) decode_lens_device_long = decode_lens_cpu.to( device=decode_block_table.device, dtype=torch.long ) req_ids = torch.repeat_interleave( torch.arange( num_decodes, device=decode_block_table.device, dtype=torch.long ), decode_lens_device_long, ) if req_ids.numel() != num_decode_tokens: pass else: expanded_block_table = decode_block_table.index_select( 0, req_ids ) decode_query_start_loc_long = decode_query_start_loc.to( device=expanded_block_table.device, dtype=torch.long ) token_starts = torch.repeat_interleave( decode_query_start_loc_long[:-1], decode_lens_device_long ) token_pos_in_req = ( torch.arange( num_decode_tokens, device=expanded_block_table.device, dtype=torch.long, ) - token_starts ) decode_seq_lens_device_long = decode_seq_lens_device.to( device=decode_block_table.device, dtype=torch.long ) context_lens_long = ( decode_seq_lens_device_long - decode_lens_device_long ) decode_seq_lens_device_long = ( context_lens_long.index_select(0, req_ids) + token_pos_in_req + 1 ) decode_seq_lens_device = decode_seq_lens_device_long.to( dtype=seq_lens.dtype ) if self.decode_use_graph or self.use_full_cuda_graph: n = expanded_block_table.shape[0] ncols = expanded_block_table.shape[1] max_rows = ( self.vllm_config.scheduler_config.max_num_batched_tokens ) assert n <= max_rows, ( f"spec-decode expanded decode rows {n} > " f"max_num_batched_tokens {max_rows}" ) if ( self._spec_decode_expand_block_table_buf is None or self._spec_decode_expand_block_table_buf.shape[0] < max_rows or self._spec_decode_expand_block_table_buf.shape[1] != ncols ): self._spec_decode_expand_block_table_buf = torch.zeros( max_rows, ncols, dtype=block_table_tensor.dtype, device=self.device, ) self._spec_decode_expand_seq_lens_buf = torch.zeros( max_rows, dtype=seq_lens.dtype, device=self.device, ) self._spec_decode_expand_block_table_buf[:n].copy_( expanded_block_table ) self._spec_decode_expand_seq_lens_buf[:n].copy_( decode_seq_lens_device ) decode_block_table = self._spec_decode_expand_block_table_buf[ :n ] decode_seq_lens_device = ( self._spec_decode_expand_seq_lens_buf[:n] ) else: decode_block_table = expanded_block_table decode_seq_lens_cpu = decode_seq_lens_device.to( device="cpu", dtype=seq_lens.dtype ) max_decode_seq_len = torch.max(decode_seq_lens_cpu).item() decode_metadata = self._build_decode( block_table_tensor=decode_block_table, seq_lens_device=decode_seq_lens_device, max_seq_len=max_seq_len, query_start_loc_cpu=decode_query_start_loc_cpu, query_start_loc_device=decode_query_start_loc, num_decode_tokens=num_decode_tokens, dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, max_decode_seq_len=max_decode_seq_len, use_cuda_graph=num_prefills == 0 and self.decode_use_graph, ) attn_metadata = self.metadata_cls( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, max_seq_len=max_seq_len, num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, head_dim=self.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, prefill=prefill_metadata, decode=decode_metadata, ) if self._use_fi_prefill and num_prefills > 0: assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) self._build_fi_prefill_wrappers(attn_metadata.prefill) return attn_metadata def reorg_kvcache( 
allgatered_kv_c_normed: torch.Tensor, allgatered_k_pe: torch.Tensor, padded_local_chunk_seq_lens_lst: list[int], local_context_lens_allranks: list[list[int]], sum_seq_len: int, max_seq_len: int, chunk_size: int, chunk_idx: int, toks: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ reorg and unpad kvcache after cp local gather to tp layout for attn kernel. e.g. allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ..., T0_4, T0_5, pad, pad, T1_2, pad, ...] -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5, T1_0, T1_1, T1_2, ...] Args: padded_local_chunk_seq_lens_lst: local chunk context lengths under current CP rank. local_context_lens_allranks: local context lengths on each CP rank. sum_seq_len: the sum of cp_chunk_seq_lens_lst. max_seq_len: the max value of cp_chunk_seq_lens_lst. chunk_size: the local padded max context chunk from chunked_context_metadata building. chunk_idx: chunk idx of chunked_prefill. toks: the number of tokens for local gather cache. """ kv_c_segments = [] k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 for padded_local_chunk_seq_len, local_context_lens in zip( padded_local_chunk_seq_lens_lst, local_context_lens_allranks ): cur_seq_len = 0 for rank, local_context_len in enumerate(local_context_lens): # Note(qcs): We split the context into multiple chunks, # depending on the size of the workspace. # local_context in dcp0: |-----------------| # local_context in dcp1: |--------------| # n*padded_local_chunk: |-----|-----|-----| # local_chunk_len in dcp1: |-----|-----|--| # so we need update the last chunk length in dcp1. local_chunk_len = min( max(0, local_context_len - chunk_idx * chunk_size), padded_local_chunk_seq_len, ) if local_chunk_len != 0: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len ] k_pe_segment = allgatered_k_pe[ rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) cur_seq_len += local_chunk_len max_seq_len_check = max(max_seq_len_check, cur_seq_len) src_token_idx += padded_local_chunk_seq_len reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) reorganized_k_pe = torch.cat(k_pe_segments, dim=0) assert reorganized_kv_c_normed.shape[0] == sum_seq_len assert reorganized_k_pe.shape[0] == sum_seq_len assert max_seq_len_check == max_seq_len return reorganized_kv_c_normed, reorganized_k_pe class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: list[float] | None, sliding_window: int | None, kv_cache_dtype: str, logits_soft_cap: float | None, attn_type: str, kv_sharing_target_layer_name: str | None, # MLA Specific Arguments q_lora_rank: int | None, kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, indexer: object | None = None, q_pad_num_heads: int | None = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim 
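        # qk_head_dim is the combined Q/K head dim seen by the prefill kernels
        # (for DeepSeek-style MLA this is qk_nope_head_dim + qk_rope_head_dim).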
self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads self.supports_quant_query_input = True # Use flashinfer's optimized concat_mla_k kernel when available. # The kernel is optimized for DeepSeek V3 dimensions: # num_heads=128, nope_dim=128, rope_dim=64 self._use_flashinfer_concat_mla_k = ( has_flashinfer() and (self.num_heads == 128) and (self.qk_nope_head_dim == 128) and (self.qk_rope_head_dim == 64) ) if use_trtllm_ragged_deepseek_prefill(): logger.info_once( "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local" ) self._run_prefill_context_chunk = ( self._run_prefill_context_chunk_trtllm_ragged ) self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged self._pad_v = False elif use_flashinfer_prefill(): logger.info_once("Using FlashInfer prefill for MLA", scope="local") self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi self._pad_v = False elif use_cudnn_prefill(): logger.info_once("Using CUDNN prefill for MLA", scope="local") self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._pad_v = False else: # Use FlashAttention if flash_attn_varlen_func is None: raise RuntimeError( "MLA attention requires FlashAttention but it is not " "available. Please install flash_attn or use " "--attention-backend ROCM_AITER_MLA." ) logger.info_once("Using FlashAttention prefill for MLA", scope="local") self.positions = None self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa # Handle the differences between the flash_attn_varlen from # flash_attn and the one from vllm_flash_attn. 
The former is used on # RoCM and the latter has an additional parameter to control # FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func self.vllm_flash_attn_version = get_flash_attn_version() # if self.vllm_flash_attn_version is not None: # self.flash_attn_varlen_func = functools.partial( # flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version # ) # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim for attention backends that do # not support different headdims # We don't need to pad V if we are on a hopper system with FA3 # self._pad_v = self.vllm_flash_attn_version is None or not ( # self.vllm_flash_attn_version == 3 # and current_platform.get_device_capability()[0] == 9 # ) self._pad_v = False self.dcp_world_size: int = -1 self.cp_kv_cache_interleave_size: int = ( get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size ) def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs ): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0 ) if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse else: # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse if vllm_is_batch_invariant(): kwargs["num_splits"] = 1 attn_out = self.flash_attn_varlen_func( q=q, k=k, v=maybe_padded_v, softmax_scale=softmax_scale, **kwargs, ) # Unpack the output if there is multiple results lse = None if isinstance(attn_out, tuple): attn_out, lse = attn_out[0], attn_out[1] # Remain consistent with old `flash_attn_varlen_func` where there # is only one output tensor if `return_softmax_lse` is False. if return_softmax_lse: return attn_out, lse return attn_out def _run_prefill_new_tokens_fa( self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse, out ): return self._flash_attn_varlen_diff_headdims( q=q, k=k, v=v, cu_seqlens_q=prefill.query_start_loc, cu_seqlens_k=prefill.query_start_loc, max_seqlen_q=prefill.max_query_len, max_seqlen_k=prefill.max_query_len, softmax_scale=self.scale, causal=True, return_softmax_lse=return_softmax_lse, out=out, ) def _run_prefill_new_tokens_fi( self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None ret = prefill.prefill_main.run( q=q, k=k, v=v, return_lse=return_softmax_lse, ) if isinstance(ret, tuple): return ret[0], ret[1].transpose(0, 1).contiguous() return ret def _run_prefill_new_tokens_cudnn( self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache output, lse = cudnn_batch_prefill_with_kv_cache( q=q, k_cache=k, v_cache=v, scale=self.scale, workspace_buffer=prefill.cudnn_workspace, max_token_per_sequence=prefill.max_query_len, max_sequence_kv=prefill.max_query_len, actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), causal=True, # Do not support False for now return_lse=True, # Indicates actual_seq_lens are on GPU or CPU. 
is_cuda_graph_compatible=True, ) if return_softmax_lse: return output, lse return output def _run_prefill_context_chunk_fa( self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v, out ): assert prefill.chunked_context is not None return self._flash_attn_varlen_diff_headdims( q=q, k=k, v=v, cu_seqlens_q=prefill.query_start_loc, cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], max_seqlen_q=prefill.max_query_len, max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], softmax_scale=self.scale, causal=False, # Context is unmasked return_softmax_lse=True, out=out, ) def _run_prefill_context_chunk_fi( self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): assert isinstance(prefill, FlashInferPrefillMetadata) attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, k=k, v=v, return_lse=True, ) # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() def _run_prefill_context_chunk_cudnn( self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.query_seq_lens is not None from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache return cudnn_batch_prefill_with_kv_cache( q=q, k_cache=k, v_cache=v, scale=self.scale, workspace_buffer=prefill.cudnn_workspace, max_token_per_sequence=prefill.max_query_len, max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( -1, 1, 1, 1 ), causal=False, return_lse=True, # Indicates actual_seq_lens are on GPU or CPU. is_cuda_graph_compatible=True, ) def _run_prefill_new_tokens_trtllm_ragged( self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): """TRT-LLM ragged attention for new tokens (causal).""" from flashinfer.prefill import trtllm_ragged_attention_deepseek assert prefill.query_seq_lens is not None assert prefill.workspace_buffer is not None # allocate BF16 / FP16 output tensor for TRT-LLM ragged attention out = torch.empty( q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=prefill.output_dtype, ) ret = trtllm_ragged_attention_deepseek( query=q, key=k, value=v, workspace_buffer=prefill.workspace_buffer, seq_lens=prefill.query_seq_lens, max_q_len=prefill.max_query_len, max_kv_len=prefill.max_query_len, bmm1_scale=self.scale, bmm2_scale=1.0, o_sf_scale=1.0, batch_size=prefill.query_seq_lens.shape[0], window_left=-1, cum_seq_lens_q=prefill.query_start_loc, cum_seq_lens_kv=prefill.query_start_loc, enable_pdl=False, is_causal=True, return_lse=return_softmax_lse, out=out, ) if isinstance(ret, tuple): # Convert from (q_len, num_heads) to (num_heads, q_len) return ret[0], ret[1].transpose(0, 1).contiguous() return ret def _run_prefill_context_chunk_trtllm_ragged( self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): """TRT-LLM ragged attention for context chunks (non-causal).""" from flashinfer.prefill import trtllm_ragged_attention_deepseek assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.workspace_buffer is not None out = torch.zeros( q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=prefill.output_dtype, ) prefill.workspace_buffer.fill_(0) attn_out, lse = trtllm_ragged_attention_deepseek( query=q, key=k, value=v, 
workspace_buffer=prefill.workspace_buffer, seq_lens=prefill.chunked_context.seq_lens[chunk_idx], max_q_len=prefill.max_query_len, max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], bmm1_scale=self.scale, bmm2_scale=1.0, o_sf_scale=1.0, batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0], window_left=-1, cum_seq_lens_q=prefill.query_start_loc, cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx], enable_pdl=False, is_causal=False, return_lse=True, out=out, ) # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." ) def get_and_maybe_dequant_weights(layer: LinearBase): if layer.quant_method is not None and not isinstance( layer.quant_method, UnquantizedLinearMethod ): # we already prepare a dequant weight for GGUF, skip it. if layer.quant_method.__class__.__name__ == "GGUFLinearMethod": weight = layer.weight.clone() del layer.weight return weight if layer.quant_method.__class__.__name__ == "AWQMarlinLinearMethod": from ixformer.inference.functions import wui4a16 return wui4a16(None, layer.qweight, layer.scales, layer.qzeros, None, layer.quant_method.quant_config.group_size, only_return_weight=True) # for W8A8, we directly dequantize it here to avoiding quantization errors if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" and not layer.scheme.is_static_input_scheme: quant_weight = layer.weight.T # output, input scales = layer.weight_scale return scaled_dequantize(quant_weight, scales, (1, -1), act_dtype) # NOTE: This should only be used offline, since it's O(N^3) eye = torch.eye( layer.input_size_per_partition, dtype=act_dtype, device=get_layer_weight(layer).device, ) dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T return layer.weight back_to_vllm = False weight_dtype = get_layer_weight(self.kv_b_proj).dtype # when use customize forward, we do not care the specific data types and shape, just reset the # _v_up_proj_and_o_proj and _q_proj_and_k_up_proj funs to get the correct results. 
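        # Reference sketch of the two projections that the branches below swap
        # in (16-bit form, mirroring the einsum fallbacks further down; the
        # quantized paths reproduce the same math with marlin W8A16 / W4A16
        # kernels):
        #   ql_nope = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)  # _k_up_proj
        #   o       = torch.einsum("bnl,lnv->bnv", x,      W_UV)  # _v_up_proj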
if envs.VLLM_MLA_CUSTOMIZE: layer = self.kv_b_proj quant_method = layer.quant_method.__class__.__name__ if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8": kv_b_proj_weight = self.kv_b_proj.weight # [i,o] kv_b_proj_weight_scale = self.kv_b_proj.weight_scale # [o, 1] kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ) W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_b_proj_weight_scale = kv_b_proj_weight_scale.view( self.num_heads, self.qk_nope_head_dim + self.v_head_dim, 1, ) W_UK_S, W_UV_S = kv_b_proj_weight_scale.split( [self.qk_nope_head_dim, self.v_head_dim], dim=1) # self.W_UK = W_UK # [kv_lora_rank, n, qk_nope] self.W_UK = ixf_ops.marlin_w8_weight_repack( weights=W_UK.permute(1, 2, 0).contiguous(), weight_format="int8", reformat="k16n16_grouped_n", ) self.W_UK_S = W_UK_S.contiguous() # [n, qk_nope, 1] # self.W_UV = W_UV # [kv_lora_rank, n, v_head_dim] self.W_UV = ixf_ops.marlin_w8_weight_repack( weights=W_UV.permute(1, 0, 2).contiguous(), weight_format="int8", reformat="k16n16", ) self.W_UV_S = W_UV_S.contiguous() # [n, v_head_dim, 1] def _v_up_proj_w8a8(self, x): # x: b, n, kv_lora # W_UV = (self.W_UV * self.W_UV_S.view(1, self.num_heads, self.v_head_dim)).to(x.dtype) # x = torch.einsum("bnl,lnv->bnv", x, W_UV) return ixf_ops.marlin_w8a16( x, self.W_UV, self.W_UV_S, group_size=-1, format="k16n16", batch_first=False, ) def _k_up_proj_w8a8(self, x_nope): # x_nope: b, n, qk_nope_head_dim # W_UK = (self.W_UK * self.W_UK_S.view(1, self.num_heads, self.qk_nope_head_dim)).to(x.dtype) # return torch.einsum("bnp,lnp->bnl", x_nope, W_UK).view(-1, self.num_heads, self.kv_lora_rank) return ixf_ops.marlin_w8a16( x_nope, self.W_UK, self.W_UK_S, group_size=-1, format="k16n16_grouped_n", batch_first=False, ) self._v_up_proj = functools.partial(_v_up_proj_w8a8, self) self._k_up_proj = functools.partial(_k_up_proj_w8a8, self) elif quant_method == "AWQMarlinLinearMethod": # ===== W_UK & W_UV use W4A16 ===== def split_4bit_matrix(qweight, scales, qzeros, num_heads, head_dim_1, head_dim_2): assert head_dim_1 % 8 == 0 assert head_dim_2 % 8 == 0 qweight = qweight.view(-1, num_heads, (head_dim_1 + head_dim_2)//8) w1, w2 = qweight.split([head_dim_1//8, head_dim_2//8], dim=-1) scales = scales.view(-1, num_heads, head_dim_1 + head_dim_2) s1, s2 = scales.split([head_dim_1, head_dim_2], dim=-1) qzeros = qzeros.view(-1, num_heads, (head_dim_1 + head_dim_2)//8) z1, z2 = qzeros.split([head_dim_1//8, head_dim_2//8], dim=-1) return (w1.contiguous().view(w1.shape[0], -1), s1.contiguous().view(s1.shape[0], -1), z1.contiguous().view(z1.shape[0], -1), w2.contiguous().view(w2.shape[0], -1), s2.contiguous().view(s2.shape[0], -1), z2.contiguous().view(z2.shape[0], -1)) # split W_UK and W_UV ( W_UK, W_UK_S, W_UK_Z, W_UV, W_UV_S, W_UV_Z ) = split_4bit_matrix( self.kv_b_proj.qweight, self.kv_b_proj.scales, self.kv_b_proj.qzeros, self.num_heads, self.qk_nope_head_dim, self.v_head_dim ) # repack W_UK W_UK = W_UK.reshape(self.kv_lora_rank, self.num_heads, -1).permute(1, 2, 0).contiguous() W_UK_S = W_UK_S.reshape(W_UK_S.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() W_UK_Z = W_UK_Z.reshape(W_UK_Z.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() ( self.W_UK, self.W_UK_S, self.W_UK_Z ) = ixf_ops.marlin_w4_weight_repack( weights = W_UK, scales = W_UK_S, zeros = W_UK_Z, weight_format = "gptq_grouped_n", reformat = "k16n32_grouped_n", ) # repack W_UV W_UV = 
W_UV.reshape(self.kv_lora_rank, self.num_heads, -1).permute(1, 0, 2).contiguous() W_UV_S = W_UV_S.reshape(W_UV_S.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() W_UV_Z = W_UV_Z.reshape(W_UV_Z.shape[0], self.num_heads, -1).permute(1, 0, 2).contiguous() ( self.W_UV, self.W_UV_S, self.W_UV_Z ) = ixf_ops.marlin_w4_weight_repack( weights = W_UV, scales = W_UV_S, zeros = W_UV_Z, weight_format = "awq", reformat = "k16n32", ) self.w4_group_size = self.kv_b_proj.quant_method.quant_config.group_size # ===== W_UK & W_UV use W16A16 ===== # kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T # kv_b_proj_weight = kv_b_proj_weight.view( # self.kv_lora_rank, # self.num_heads, # self.qk_nope_head_dim + self.v_head_dim, # ) # self.W_UK_fp16, self.W_UV_fp16 = kv_b_proj_weight.split( # [self.qk_nope_head_dim, self.v_head_dim], dim=-1) def _v_up_proj_w4a16(self, x): # W16A16 # res = torch.einsum("bnl,lnv->bnv", x, self.W_UV_fp16) # W4A16 return ixf_ops.marlin_w4a16( x, self.W_UV, self.W_UV_S, self.W_UV_Z, group_size=self.w4_group_size, format="k16n32", batch_first=False) def _k_up_proj_w4a16(self, x_nope): # W16A16 # return torch.einsum("bnp,lnp->bnl", x_nope, self.W_UK_fp16).view(-1, self.num_heads, self.kv_lora_rank), x_pe # W4A16 return ixf_ops.marlin_w4a16( x_nope, self.W_UK, self.W_UK_S, self.W_UK_Z, group_size=self.w4_group_size, format="k16n32_grouped_n", batch_first=False, ) self._v_up_proj = functools.partial(_v_up_proj_w4a16, self) self._k_up_proj = functools.partial(_k_up_proj_w4a16, self) else: print(f"Custom MLA for quant method: {layer.quant_method.__class__.__name__} is not supported, will use the vllm official impl.") back_to_vllm = True else: back_to_vllm = True if back_to_vllm: # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( f"{kv_b_proj_weight.shape=}, " f"{self.kv_lora_rank=}, " f"{self.num_heads=}, " f"{self.qk_nope_head_dim=}, " f"{self.v_head_dim=}") kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ) W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) # # Convert from (L, N, V) to (N, L, V) # self.W_UV = W_UV.transpose(0, 1) # # Convert from (L, N, P) to (N, P, L) # self.W_UK_T = W_UK.permute(1, 2, 0) self.W_UV = W_UV self.W_UK = W_UK def _v_up_proj(self, x: torch.Tensor): # Convert from (B, N, L) to (N, B, L) # x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # out = out.view(-1, self.num_heads, self.v_head_dim) # if self.is_aiter_triton_fp4_bmm_enabled: # out = rocm_aiter_ops.batched_gemm_a16wfp4( # x, # self.W_V, # self.W_V_scale, # out, # transpose_bm=True, # prequant=True, # y_scale=None, # ) # x = out.view(-1, self.num_heads * self.v_head_dim) # elif self.is_aiter_triton_fp8_bmm_enabled: # # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) # x = rocm_aiter_ops.triton_fp8_bmm( # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out # ) # else: # # Convert from (B, N * V) to (N, B, V) # out = out.transpose(0, 1) # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) # torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" # # Convert from (N, B, V) to (B, N * V) # 
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        # # Adjust output buffer shape back to the original (B, N * V)
        # N, B, V = out.shape
        # out.resize_((B, N * V))
        # out.copy_(out_new)  # Copy result
        return torch.einsum("bnl,lnv->bnv", x, self.W_UV)

    def _k_up_proj(self, q_nope):
        # # Convert from (B, N, P) to (N, B, P)
        # q_nope = q_nope.transpose(0, 1)
        # # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        # ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # # Convert from (N, B, L) to (B, N, L)
        # return ql_nope.transpose(0, 1), q_pe
        return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(
            -1, self.num_heads, self.kv_lora_rank
        )

    def _concat_k_nope_k_pe(
        self, k_nope: torch.Tensor, k_pe: torch.Tensor
    ) -> torch.Tensor:
        """
        Efficiently concatenate k_nope and k_pe tensors along the last dimension.

        This function avoids the performance penalty of torch.cat with expanded
        non-contiguous tensors by pre-allocating the output and using direct
        copies.

        Args:
            k_nope: Tensor of shape [..., nope_dim]
            k_pe: Tensor to broadcast and concatenate, typically shape
                [..., 1, pe_dim] or [..., pe_dim]

        Returns:
            Tensor of shape [..., nope_dim + pe_dim]
        """
        k = torch.empty(
            (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
            dtype=k_nope.dtype,
            device=k_nope.device,
        )
        if self._use_flashinfer_concat_mla_k:
            torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
        else:
            # Fallback: Direct copies with efficient broadcasting
            k[..., : k_nope.shape[-1]] = k_nope
            k[..., k_nope.shape[-1] :] = k_pe
        return k

    def _compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        kv_c_and_k_pe_cache_scale: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ):
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None

        use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace

        if use_fp8_prefill:
            q = q.to(prefill_metadata.q_data_type)

        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]
            if envs.VLLM_USE_INT8_MLA:
                ops.gather_cache_int8(
                    src_cache=kv_c_and_k_pe_cache,
                    src_cache_scale=kv_c_and_k_pe_cache_scale,
                    kv_lora_rank=self.kv_lora_rank,
                    dst=workspace,
                    block_table=prefill_metadata.block_table,
                    cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                    batch_size=attn_metadata.num_prefills,
                    seq_starts=prefill_metadata.chunked_context.starts[i],
                )
            else:
                ops.gather_cache(
                    src_cache=kv_c_and_k_pe_cache,
                    dst=workspace,
                    block_table=prefill_metadata.block_table,
                    cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                    batch_size=attn_metadata.num_prefills,
                    seq_starts=prefill_metadata.chunked_context.starts[i],
                )
            kv_c_normed = workspace[:toks][..., : self.kv_lora_rank].contiguous()
            k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            # TODO: Use epilogue of kv_b_proj to generate fp8 kv_nope.
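            # The per-chunk outputs below are folded into the running result via
            # their log-sum-exp terms. A reference sketch of that merge
            # (illustrative only; the fused kernel is merge_attn_states):
            #   w_prefix = 1 / (1 + exp(suffix_lse - prefix_lse))
            #   out      = w_prefix * prefix_output + (1 - w_prefix) * suffix_output
            #   lse      = prefix_lse + log1p(exp(suffix_lse - prefix_lse))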
if use_fp8_prefill: kv_nope = kv_nope.to(prefill_metadata.q_data_type) k_pe = k_pe.to(prefill_metadata.q_data_type) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, chunk_idx=i, q=q, k=k, v=v, out=attn_output, ) if output is None: output = attn_output output_lse = attn_softmax_lse else: output,output_lse = merge_attn_states( prefix_output=output, prefix_lse=output_lse, suffix_output=attn_output, suffix_lse=attn_softmax_lse, return_lse=True, ) return output, output_lse def _context_parallel_compute_prefill_context( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, dcp_world_size: int, ): assert k_scale is None, "DCP not support scaled kvcache now." assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None assert prefill_metadata.chunked_context.local_context_lens_allranks is not None assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None assert prefill_metadata.chunked_context.chunk_size is not None output = None iters = len(prefill_metadata.chunked_context.seq_tot) workspace = prefill_metadata.chunked_context.workspace for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] ops.cp_gather_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ i ], batch_size=attn_metadata.num_prefills, seq_starts=prefill_metadata.chunked_context.starts[i], ) # workspace # |------- N tokens --------|--------- N*dcp_size tokens ----------| # |<- use for loca_gather ->|<--------- use for allgather -------->| allgather_offset = workspace.shape[0] // (dcp_world_size + 1) assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] assert toks <= allgather_offset local_gathered_kvcache = workspace[:toks] cur_allgather_workspace = workspace[ allgather_offset : allgather_offset * (1 + dcp_world_size) ] assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] cur_allgather_kvcache.copy_( get_dcp_group().all_gather(local_gathered_kvcache, dim=0) ) assert ( cur_allgather_kvcache.shape[-1] == self.kv_lora_rank + self.qk_rope_head_dim ) allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze( 1 ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[ i ], local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks, sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], chunk_size=prefill_metadata.chunked_context.chunk_size, chunk_idx=i, toks=toks, ) kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], 
dim=-1)
            k = self._concat_k_nope_k_pe(k_nope, k_pe)

            attn_output = torch.empty(
                q.shape[0],
                self.num_heads,
                self.v_head_dim,
                dtype=q.dtype,
                device=q.device,
            )
            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
                chunk_idx=i,
                q=q,
                k=k,
                v=v,
                out=attn_output,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                # Keep the returned tensors so the running output/LSE actually
                # reflect this chunk.
                output, output_lse = merge_attn_states(
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                    return_lse=True,
                )

        return output, output_lse

    def forward_mha(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        kv_c_and_k_pe_cache_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # TODO (zyongye): Prefill function here
        assert attn_metadata.prefill is not None
        assert self.dcp_world_size != -1
        prefill_metadata = attn_metadata.prefill
        use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()

        # Convert q to FP8 if FP8 prefill attention is enabled
        if use_fp8_prefill:
            q = q.to(prefill_metadata.q_data_type)

        has_context = prefill_metadata.chunked_context is not None
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v_nope = kv_nope.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        v = v_nope
        k[..., : self.qk_nope_head_dim] = k_nope

        attn_output = torch.empty(
            q.shape[0],
            self.num_heads,
            self.v_head_dim,
            dtype=q.dtype,
            device=q.device,
        )
        output = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,
            q=q,
            k=k,
            v=v,
            return_softmax_lse=has_context,
            out=attn_output,
        )

        if has_context:
            suffix_output, suffix_lse = output
            if self.dcp_world_size > 1:
                context_output, context_lse = (
                    self._context_parallel_compute_prefill_context(
                        q,
                        kv_c_and_k_pe_cache,
                        attn_metadata,
                        k_scale=None,
                        dcp_world_size=self.dcp_world_size,
                    )
                )
            else:
                context_output, context_lse = self._compute_prefill_context(
                    q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale, attn_metadata
                )

            output = torch.empty_like(suffix_output)
            output = merge_attn_states(
                output=output,
                prefix_output=context_output,
                prefix_lse=context_lse,
                suffix_output=suffix_output,
                suffix_lse=suffix_lse,
            )

        # unpad if necessary
        if self._pad_v:
            output = output[..., : v.shape[-1]]

        return output

    @abstractmethod
    def forward_mqa(
        self,
        ql_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        k_c_normed: torch.Tensor | None,
        k_pe: torch.Tensor | None,
        kv_c_and_k_pe_cache_scale: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        raise NotImplementedError
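    # Concrete backends implement forward_mqa for the decode path: attention
    # runs directly over the latent cache, so the effective QK head dim is
    # kv_lora_rank + qk_rope_head_dim and the "value" is the latent itself.
    # A schematic sketch (hypothetical paged_attention helper, not any real
    # backend):
    #   q = torch.cat([ql_nope, q_pe], dim=-1)       # [B, N, Lkv + R]
    #   o = paged_attention(q, kv_c_and_k_pe_cache,  # keys: latent + rope
    #                       value=latent_part_only)  # values: latent only
    #   return self._v_up_proj(o)                    # back to [B, N, V]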