# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionLayer,
    MultipleOf,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import (
    flash_mla_sparse_prefill,
    flash_mla_with_kvcache,
    get_mla_metadata,
)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
    split_decodes_and_prefills,
    split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer

logger = init_logger(__name__)

# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode;
#    this is done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
#    (upconverting the FP8 cache to BF16, then calling the prefill kernel) and
#    use the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. high TP),
# since the BF16 prefill kernel requires padding the number of heads (to 64 on
# Hopper / 128 on Blackwell, see FlashMLASparseImpl.padding) while the decode
# kernel does not. So when the per-rank head count is below
# MIN_HEADS_FOR_BF16_PREFILL we use the mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL = 32


"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format

In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
-   **First 512 bytes:** The "quantized NoPE" part, containing 512
    `float8_e4m3` values.
-   **Next 16 bytes:** Scale factors, containing 4 `float32` values.
    The first `float32` is the scale for the first 128 `float8_e4m3` values,
    the second for the next 128, and so on.
-   **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
    part is not quantized for accuracy.
"""


class FlashMLASparseBackend(AttentionBackend):
    """Backend descriptor for sparse (top-k token) FlashMLA attention."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [64]

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        return FlashMLASparseImpl

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability.major in [9, 10]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache tensor shape for the given cache dtype.

        The "fp8_ds_mla" layout stores 656 bytes per token (see the format
        note above); otherwise each token holds ``head_size`` elements.
        """
        if cache_dtype_str == "fp8_ds_mla":
            # custom storage format is 656 bytes
            #  see FlashMLA readme.md for details
            return (num_blocks, block_size, 656)
        else:
            return (num_blocks, block_size, head_size)


@dataclass
class FlashMLASparseMetadata:
    """Per-batch metadata consumed by FlashMLASparseImpl."""

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8KernelMetadata:
        """Scheduling/metadata buffers passed to the FP8 FlashMLA kernel."""

        scheduler_metadata: torch.Tensor | None
        num_splits: torch.Tensor
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    @dataclass
    class FP8SeperatePrefillDecode:
        """Metadata for the separate FP8-decode / BF16-prefill path."""

        @dataclass
        class Decode:
            kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
            decode_query_len: int  # needed for reshape in spec decode

        @dataclass
        class Prefill:
            # Sequence lengths (context + query) for prefill requests
            # Shape: [num_prefill_reqs]
            seq_lens: torch.Tensor

            # Request ID for each token: -1 for decode tokens, request index
            # (0, 1, 2, ...) for prefill tokens.
            # Shape: [num_actual_tokens]
            request_ids: torch.Tensor

            # Workspace start offsets for all prefill requests
            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
            # 0-indexed within each chunk. Used to map prefill tokens to
            # workspace offsets in triton_convert_req_index_to_global_index.
            workspace_starts: torch.Tensor

            @dataclass
            class Chunk:
                """Metadata for a chunk of prefill requests.

                Prefill requests may be chunked to fit within the fixed workspace size.
                """

                seq_lens: torch.Tensor
                tokens_slice: slice
                block_table: torch.Tensor
                req_start_idx: int
                workspace_starts: torch.Tensor
                chunk_tot_seqlen: int

            chunks: list[Chunk]

        num_prefills: int = 0
        num_decodes: int = 0
        num_prefill_tokens: int = 0
        num_decode_tokens: int = 0

        decode: Decode | None = None
        prefill: Prefill | None = None

    fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
    fp8_use_mixed_batch: bool = False


# Kernel with prefill workspace support
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Negative token indices (the -1 padding sentinel) propagate as -1
    is_invalid_tok = tok < 0
    is_prefill = False
    if HAS_PREFILL:
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0
    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    is_invalid_tok |= ~valid_block
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)
    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)


def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
):
    """
    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Negative entries of token_indices (the -1 padding sentinel) produce -1.
    For safety, we also output -1 if the derived block_id would be
        out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
    instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.

    prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
        prefill request index (maps to prefill_workspace_starts)
    prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
        starts for each prefill request
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )

    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32

    num_tokens = req_id.shape[0]
    max_num_blocks_per_req = block_table.shape[1]
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Prepare prefill pointers
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None  # for mypy
        assert prefill_workspace_starts is not None  # for mypy
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)

    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out


def get_prefill_workspace_size(max_model_len: int) -> int:
    """Return the number of token rows reserved for the BF16 prefill workspace."""
    # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
    # May be tuned later.
    # Memory usage: 5 * max_model_len * 576 * 2 bytes
    #   Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #            5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB so this is "free"
    return max_model_len * 5


class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds FlashMLASparseMetadata, including FP8 kernel scheduling data."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        # Treat requests with query length <= 1 as decodes to match the
        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        props = torch.cuda.get_device_properties(device)
        sm_count = props.multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
        self.topk_tokens_tensor = torch.full(
            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
        )
        # Shape: [max_num_seqs], all elements = max_model_len
        self.max_model_len_tensor = torch.full(
            (max_num_seqs,),
            self.model_config.max_model_len,
            device=device,
            dtype=torch.int32,
        )
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/pybind.cpp
        h_q, h_k = self.num_heads, 1
        s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
        max_num_sm_parts = int(
            max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
        )
        if current_platform.is_device_capability_family(100):
            max_num_sm_parts *= 2
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # Sized for per-request batching (num_decodes + 1)
        self.num_splits_buffer = torch.empty(
            (max_num_seqs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def _build_fp8_mixed_decode_prefill(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Build FP8 metadata treating all tokens as one mixed batch.

        This matches main branch's approach and avoids the BF16 prefill kernel
        which has head padding overhead when num_heads is small (high TP case).
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Build metadata for all tokens as a single batch
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
            num_q_tokens_per_head_k=num_tokens * self.num_heads,
            topk=self.topk_tokens,
            num_heads_q=self.num_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        num_sm_parts = tile_scheduler_metadata.size(0)
        tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
            :num_sm_parts
        ]
        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)

        num_splits_view = self.num_splits_buffer[:2]
        num_splits_view.copy_(num_splits)

        fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=tile_scheduler_metadata_buffer,
            num_splits=num_splits_view,
            cache_lens=self.max_model_len_tensor[:1],
            dummy_block_table=self.dummy_block_table[:1],
        )

        return fp8_metadata

    def _build_fp8_separate_prefill_decode(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
        """Build metadata for the separate FP8-decode / BF16-prefill path.

        Decode requests (which come first in the batch) get FP8 kernel
        scheduling metadata. Prefill requests get per-token request ids,
        workspace start offsets, and a list of chunks sized to fit the BF16
        prefill workspace.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold or 1,
                require_uniform=True,
            )
        )

        FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
        )

        # Extract prefill sequence lengths (context + query, not just query).
        # Decode requests come first in the batch, prefill requests follow.
        # For pure decode batches fp8_metadata.prefill stays None; for mixed
        # batches request_ids holds -1 for decode and the request index for
        # prefill tokens.
        if num_prefills > 0:
            seq_lens_cpu = common_attn_metadata.seq_lens_cpu
            seq_lens = common_attn_metadata.seq_lens
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

            prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
            prefill_seq_lens = seq_lens[num_decodes:]

            # Build prefill_request_id: -1 for decode, request index for
            # prefill. This enables a single
            # triton_convert_req_index_to_global_index call for all tokens.
            # Built on the host and copied once, rather than issuing one
            # device fill per prefill request.
            query_starts_np = np.asarray(query_start_loc_cpu, dtype=np.int32)
            prefill_query_bounds = query_starts_np[
                num_decodes : num_decodes + num_prefills + 1
            ]
            prefill_request_id_np = np.full(num_tokens, -1, dtype=np.int32)
            prefill_request_id_np[
                prefill_query_bounds[0] : prefill_query_bounds[-1]
            ] = np.repeat(
                np.arange(num_prefills, dtype=np.int32),
                np.diff(prefill_query_bounds),
            )
            prefill_request_id = torch.from_numpy(prefill_request_id_np).to(
                self.device, non_blocking=True
            )

            # will be adjusted by chunk loop
            prefill_workspace_starts_cpu = torch.zeros(
                num_prefills, dtype=torch.int32, pin_memory=True
            )
            prefill_workspace_starts_cpu[1:] = torch.cumsum(
                prefill_seq_lens_cpu[:-1], dim=0
            )
            # populated by non-blocking copy after prefill_workspace_starts_cpu is
            # updated by each chunk
            prefill_workspace_starts = torch.empty(
                num_prefills, dtype=torch.int32, device=self.device
            )

            # Chunk prefill requests to fit within workspace size
            max_prefill_buffer_size = get_prefill_workspace_size(
                self.vllm_config.model_config.max_model_len
            )
            chunk_bounds = split_prefill_chunks(
                prefill_seq_lens_cpu, max_prefill_buffer_size
            )

            prefill_chunks = []
            for chunk_start, chunk_end in chunk_bounds:
                # Adjust workspace_starts in-place per chunk to be
                # 0-indexed within each chunk
                # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
                #   Initial: workspace_starts=[0,10,25,45]
                #   After:   workspace_starts=[0,10,0,20]
                #           (chunk 0 starts at 0, chunk 1 starts at 0)
                offset = prefill_workspace_starts_cpu[chunk_start].item()
                prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

                chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
                # Plain int, matching the Chunk.chunk_tot_seqlen annotation.
                chunk_tot_seqlen = int(
                    prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
                )
                token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
                token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
                tokens_slice = slice(token_start, token_end)

                # Create chunk view of gpu tensor
                chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
                chunk_block_table = common_attn_metadata.block_table_tensor[
                    num_decodes + chunk_start : num_decodes + chunk_end
                ]

                prefill_chunks.append(
                    FP8Meta.Prefill.Chunk(
                        seq_lens=chunk_seq_lens,
                        tokens_slice=tokens_slice,
                        block_table=chunk_block_table,
                        req_start_idx=chunk_start,
                        workspace_starts=chunk_workspace_starts,
                        chunk_tot_seqlen=chunk_tot_seqlen,
                    )
                )

            prefill_workspace_starts.copy_(
                prefill_workspace_starts_cpu, non_blocking=True
            )

            fp8_metadata.prefill = FP8Meta.Prefill(
                seq_lens=prefill_seq_lens,
                request_ids=prefill_request_id,
                workspace_starts=prefill_workspace_starts,
                chunks=prefill_chunks,
            )

        if num_decodes > 0:
            # Compute decode_query_len for spec decode (uniform due to require_uniform)
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

            tile_scheduler_metadata, num_splits = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
                num_q_tokens_per_head_k=decode_query_len * self.num_heads,
                topk=self.topk_tokens,
                num_heads_q=self.num_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            num_sm_parts = tile_scheduler_metadata.size(0)
            # Copy to persistent buffer for full-CG support
            tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
                :num_sm_parts
            ]
            tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
            # num_splits has size [num_decodes + 1]
            num_splits_view = self.num_splits_buffer[: num_decodes + 1]
            num_splits_view.copy_(num_splits)

            kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=tile_scheduler_metadata_buffer,
                num_splits=num_splits_view,
                dummy_block_table=self.dummy_block_table[:num_decodes],
                cache_lens=self.max_model_len_tensor[:num_decodes],
            )
            fp8_metadata.decode = FP8Meta.Decode(
                kernel_metadata=kernel_meta,
                decode_query_len=decode_query_len,
            )

        return fp8_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        """Build the per-batch FlashMLASparseMetadata.

        Computes the request index of every token into a persistent buffer
        and, for the FP8 cache, the extra metadata of whichever FP8 path
        (mixed batch or separate prefill/decode) the head count selects.
        """
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata: (
            FlashMLASparseMetadata.FP8SeperatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
        fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
        if self.use_fp8_kv_cache:
            if fp8_use_mixed_batch:
                fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
            else:
                fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

        metadata = FlashMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=cm.slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            fp8_use_mixed_batch=fp8_use_mixed_batch,
        )

        return metadata


class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
    """Sparse FlashMLA attention using indexer-selected top-k tokens."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        # Head-count multiple required by the BF16 sparse prefill kernel.
        self.padding = 128 if current_platform.is_device_capability_family(100) else 64

        if kv_cache_dtype == "fp8_ds_mla":
            # Reserve workspace during initialization
            vllm_config = get_current_vllm_config()
            assert vllm_config is not None and vllm_config.model_config is not None
            prefill_workspace_size = get_prefill_workspace_size(
                vllm_config.model_config.max_model_len
            )
            self.prefill_workspace_shape = (prefill_workspace_size, head_size)
            (self.prefill_bf16_workspace,) = (
                current_workspace_manager().get_simultaneous(
                    (self.prefill_workspace_shape, torch.bfloat16)
                )
            )

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Attention over a BF16 KV cache using the sparse prefill kernel."""
        # Convert per-request indices to global cache slots (no prefill
        # workspace is used on this path).
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)

    def _forward_fp8_kv_separate_prefill_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """FP8 cache path: FP8 decode kernel for decodes, BF16 kernel for prefills.

        Prefill chunks are first gathered and upconverted from the FP8 cache
        into the BF16 workspace, then attended with the BF16 sparse kernel.
        """
        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes

        prefill_request_ids = None
        prefill_workspace_starts = None
        has_prefill_workspace = False
        if fp8_metadata.prefill is not None:
            prefill_request_ids = fp8_metadata.prefill.request_ids
            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
            has_prefill_workspace = True

        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        # Prefill tokens use the workspace mapping (FP8 cache upconverted to
        # BF16); decode tokens always map to global cache slots.
        # prefill_workspace_starts has been adjusted in-place per chunk so
        # prefill indices automatically come out chunk-local
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            HAS_PREFILL_WORKSPACE=has_prefill_workspace,
            prefill_workspace_request_ids=prefill_request_ids,
            prefill_workspace_starts=prefill_workspace_starts,
        )

        def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            #         -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            #                    -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            assert fp8_metadata.decode is not None
            attn_out, _ = self._fp8_flash_mla_kernel(
                q=q,
                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                topk_indices=topk_indices,
                kernel_metadata=fp8_metadata.decode.kernel_metadata,
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            #              -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(attn_out)

        num_decode_tokens = fp8_metadata.num_decode_tokens
        num_prefill_tokens = fp8_metadata.num_prefill_tokens

        # Pure decode: direct call without allocation
        if num_decode_tokens > 0 and num_prefill_tokens == 0:
            assert fp8_metadata.decode is not None
            attn_out = _fp8_decode(q, topk_indices)
        else:
            # Mixed or pure prefill: allocate output tensor
            attn_out = q.new_empty(
                (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
                dtype=q.dtype,
                device=q.device,
            )

            if num_decode_tokens > 0:
                attn_out[:num_decode_tokens] = _fp8_decode(
                    q[:num_decode_tokens], topk_indices[:num_decode_tokens]
                )

            assert fp8_metadata.prefill is not None
            for chunk in fp8_metadata.prefill.chunks:
                chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
                ops.cp_gather_and_upconvert_fp8_kv_cache(
                    kv_c_and_k_pe_cache,
                    chunk_workspace,
                    chunk.block_table,
                    chunk.seq_lens,
                    chunk.workspace_starts,
                    len(chunk.block_table),
                )

                chunk_q = q[chunk.tokens_slice]
                chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

                attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                    chunk_q,
                    chunk_workspace,
                    chunk_topk_indices_workspace,
                )

        return attn_out

    def _forward_fp8_kv_mixed_batch(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Mixed batch FP8 forward path that treats all tokens as one batch.

        This is equivalent to main branch's approach and avoids the BF16
        prefill kernel which has head padding overhead when num_heads is small.
        Used when use_mixed_batch is True.
        """
        # Convert per-request indices to global cache slots (no prefill
        # workspace is used on this path).
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        assert attn_metadata.fp8_extra_metadata is not None
        assert isinstance(
            attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
        )
        fp8_metadata = attn_metadata.fp8_extra_metadata

        _attn_out, _ = self._fp8_flash_mla_kernel(
            q=q.unsqueeze(0),  # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=topk_indices.unsqueeze(0),  # (T, topk) -> (1, T, topk)
            kernel_metadata=fp8_metadata,
        )

        # Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
        return _attn_out.squeeze(0)

    def _fp8_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the FP8 FlashMLA kernel with sparse indices.

        Returns the pair produced by ``flash_mla_with_kvcache``; callers use
        the first element as the attention output.
        """
        return flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=kernel_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=kernel_metadata.cache_lens,
            tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
            num_splits=kernel_metadata.num_splits,
            is_fp8_kvcache=True,
            indices=topk_indices,
            softmax_scale=self.softmax_scale,
        )

    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Call the BF16 sparse prefill kernel, padding heads if required."""
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on hopper and 128 on blackwell
        if self.num_heads % self.padding != 0:
            assert self.padding % self.num_heads == 0
            logger.warning_once(
                f"padding num_heads to {self.padding} "
                "due to sparse attn kernel requirement"
            )
            q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_prefill(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )[0]
        # Drop any padded heads.
        output = output[:, : self.num_heads, :]
        return output

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata | None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Write K/V into the cache, run sparse attention, and fill ``output``."""
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLACommonImpl"
            )

        if attn_metadata is None:
            # Dummy run - no need to allocate buffers
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs

        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if not use_fp8_cache:
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
        elif attn_metadata.fp8_use_mixed_batch:
            attn_out = self._forward_fp8_kv_mixed_batch(
                q, kv_cache, topk_indices, attn_metadata
            )
        else:
            attn_out = self._forward_fp8_kv_separate_prefill_decode(
                q, kv_cache, topk_indices, attn_metadata
            )

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output