# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass from typing import TYPE_CHECKING, ClassVar, Optional import numpy as np import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata) from vllm.attention.backends.utils import get_mla_dims from vllm_kunlun.ops.attention.flashmla import (flash_mla_sparse_prefill, flash_mla_with_kvcache, get_mla_metadata, kunlun_flash_mla_with_kvcache) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.distributed import get_tp_group if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer logger = init_logger(__name__) """ NOTE: FlashMLA Sparse uses an fp8 cache with the following format In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as: - **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values. - **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on. - **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy. """ def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: # Convert base-2 LSE to natural-log LSE # Keep FP32 for numerical stability during the merge. return (lse_base2.to(torch.float32) * math.log(2.0)) class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod def get_name() -> str: return "FLASHMLA_SPARSE_VLLM_V1" @staticmethod def get_metadata_cls() -> type[AttentionMetadata]: return FlashMLASparseMetadata @staticmethod def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: return FlashMLASparseMetadataBuilder @staticmethod def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if cache_dtype_str == "fp8_ds_mla": # custom storage fromat is 656 bytes # see FlashMLA readme.md for details return (num_blocks, block_size, 656) else: return (num_blocks, block_size, head_size) @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @dataclass class MLASparsePrefillMetadata: # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because # the kernel is not from flashmla block_table: torch.Tensor = None has_context: bool = False context_lens: Optional[torch.Tensor] = None # Sequence lengths (context + query) for prefill requests # Shape: [num_prefill_reqs] seq_lens: torch.Tensor = None # Request ID for each token: -1 for decode tokens, request index # (0, 1, 2, ...) for prefill tokens. # Shape: [num_actual_tokens] request_ids: torch.Tensor = None query_start_loc: torch.Tensor = None query_start_loc_cpu: torch.Tensor = None @dataclass class FlashMLASparseDecodeAndContextMetadata: scheduler_metadata: torch.Tensor = None num_splits: torch.Tensor = None cache_lens: torch.Tensor = None prefill_context_lengths: Optional[torch.Tensor] = None prefill_new_k_start_locs: Optional[torch.Tensor] = None dummy_block_table: torch.Tensor = None seq_lens: torch.Tensor = None seq_lens_cpu: torch.Tensor = None max_seq_len: int = -1 # needed for reshape in spec decode def filter_prefill_indices( self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert self.prefill_context_lengths is not None prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) context_indices = torch.where(indices < prefill_context_lengths, indices, -1) new_token_indices = torch.where(indices >= prefill_context_lengths, indices - prefill_context_lengths, -1) return context_indices, new_token_indices @dataclass class FlashMLASparseMetadata: num_reqs: int max_query_len: int max_seq_len: int num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor req_id_per_token: torch.Tensor block_size: int = 64 topk_tokens: int = 2048 num_prefills: int = 0 num_decodes: int = 0 num_prefill_tokens: int = 0 num_decode_tokens: int = 0 decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None prefill_metadata: Optional[MLASparsePrefillMetadata] = None @dataclass class FP8KernelMetadata: scheduler_metadata: Optional[torch.Tensor] num_splits: torch.Tensor dummy_block_table: torch.Tensor cache_lens: torch.Tensor fp8_extra_metadata: Optional[FP8KernelMetadata] = None @triton.jit def _convert_req_index_to_global_index_kernel( req_id_ptr, # int32 [num_tokens] block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] # shapes (compile-time where possible) max_num_blocks_per_req: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, # tile width along columns # strides (in elements) bt_stride0, bt_stride1, ti_stride0, ti_stride1, out_stride0, out_stride1, ): # program_id(0) -> token_id (row) # program_id(1) -> tile index along columns token_id = tl.program_id(0) tile_id = tl.program_id(1) # Each program covers BLOCK_N consecutive columns indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) # Load request id for this token (no mask: grid is exact) req = tl.load(req_id_ptr + token_id) # Load token indices for this tile ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 tok = tl.load(ti_ptr) # int32 # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 # Compute block id and in-block offset block_id = tok // BLOCK_SIZE inblock_off = tok % BLOCK_SIZE # Guard block_table access valid_block = block_id < max_num_blocks_per_req bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 base = tl.load(bt_ptr, mask=valid_block, other=0) # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset out_val = tl.where(is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 tl.store(out_ptr_ij, out_val) def triton_convert_req_index_to_global_index( req_id: torch.Tensor, # int32 [num_tokens] block_table: torch. Tensor, # int32 [num_requests, max_num_blocks_per_req] token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] BLOCK_SIZE: int = 64, NUM_TOPK_TOKENS: int = 2048, BLOCK_N: int = 128, # tile width along columns ): """ out[token_id, indice_id] = block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + token_indices[token_id, indice_id] % BLOCK_SIZE Only when token_indices[token_id, indice_id] == -1 do we output -1. For safety, we also output -1 if the derived block_id would be out-of-bounds. """ assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ f"BLOCK_N ({BLOCK_N})" num_tokens = req_id.shape[0] num_requests, max_num_blocks_per_req = block_table.shape tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N # Ensure contiguous tensors on the same device req_id_c = req_id.contiguous() block_table_c = block_table.contiguous() token_indices_c = token_indices.contiguous() out = torch.empty_like(token_indices_c) # Strides in elements bt_stride0, bt_stride1 = block_table_c.stride() ti_stride0, ti_stride1 = token_indices_c.stride() out_stride0, out_stride1 = out.stride() # Exact 2D grid: tokens × column tiles grid = (num_tokens, tiles_per_row) _convert_req_index_to_global_index_kernel[grid]( req_id_c, block_table_c, token_indices_c, out, # shapes / constexprs max_num_blocks_per_req, BLOCK_SIZE, BLOCK_N, # strides bt_stride0, bt_stride1, ti_stride0, ti_stride1, out_stride0, out_stride1, ) return out def kunlun_convert_req_index_to_global_index( req_id: torch.Tensor, # int32 [num_tokens] block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] BLOCK_SIZE: int = 64, NUM_TOPK_TOKENS: int = 2048, ): assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS num_tokens = req_id.shape[0] num_requests, max_num_blocks_per_req = block_table.shape out = torch.zeros_like(token_indices) # Compute block_id and inblock_off for all tokens at once block_id = token_indices // BLOCK_SIZE inblock_off = token_indices % BLOCK_SIZE # Create mask for invalid tokens (tok < 0) invalid_tok_mask = token_indices < 0 # Create mask for out-of-bounds block_id oob_block_mask = block_id >= max_num_blocks_per_req # Combine masks - output -1 for either condition invalid_mask = invalid_tok_mask | oob_block_mask # Get request IDs expanded to match token_indices shape req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS) # Gather base addresses from block_table # Clamp block_id to avoid index errors (we'll mask these out anyway) block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1) # Use advanced indexing to get base addresses base_addrs = block_table[req_ids_expanded, block_id_clamped] # Compute the global indices global_indices = base_addrs * BLOCK_SIZE + inblock_off # Apply mask: set invalid positions to -1 out = torch.where(invalid_mask, torch.tensor(-1, dtype=torch.int32, device=token_indices.device), global_indices) return out def kunlun_concat_and_cache_mla( kv_c: torch.Tensor, #[num_tokens, kv_lora_rank] k_pe: torch.Tensor, #[num_tokens, pe_dim] kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)] slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens] kv_cache_dtype: str, scale: torch.Tensor ): num_tokens = slot_mapping.shape[0] kv_lora_rank = kv_c.shape[1] pe_dim = k_pe.shape[1] block_size = kv_cache.shape[1] def kunlun_fp8_ds_mla(): for token_idx in range(num_tokens): slot = slot_mapping[token_idx].item() if slot < 0: continue block_idx = slot // block_size block_offset = slot % block_size kv_c_i = kv_c[token_idx].view(4,kv_lora_rank//4).contiguous() kv_c_i_int8 = torch.zeros( kv_c_i.shape, device=kv_c.device, dtype=torch.int8, ) kv_c_i_scale = torch.zeros( [kv_c_i.shape[0], 1], device=kv_c.device, dtype=torch.float32, ) torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True) kv_c_i_scale /= 127 kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous() kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous() kv_cache[block_idx, block_offset, kv_lora_rank+16:] = k_pe[token_idx, :].view(torch.uint8).contiguous() def kunlun_mla(): for token_idx in range(num_tokens): slot = slot_mapping[token_idx].item() if slot < 0: continue block_idx = slot // block_size block_offset = slot % block_size kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous() kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous() if (kv_cache_dtype == "fp8_ds_mla"): assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla" assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla" assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla" assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla" assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla" kunlun_fp8_ds_mla() else: assert kv_cache.shape[2] == kv_lora_rank + pe_dim kunlun_mla() @dataclass class FlashMLASparseMetadataBuilder( AttentionMetadataBuilder[FlashMLASparseMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.layer_names = layer_names cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device # Treat requests with query length <= 1 as decodes to match the # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) # 从最新版本vllm中引入的 self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" self.topk_tokens_tensor = torch.tensor([self.topk_tokens], device=device, dtype=torch.int32) # self.max_model_len_tensor = torch.tensor( # [self.model_config.max_model_len], # device=device, # dtype=torch.int32) # this is ignored by `flash_mla_with_kvcache` if indices not None self.dummy_block_table = torch.empty((1, 1), dtype=torch.int32, device=self.device) # Equation taken from FlashMLA/csrc/pybind.cpp h_q, h_k = self.num_heads, 1 s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest max_num_sm_parts = int( max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) if current_platform.is_device_capability(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.zeros( # TileSchedulerMetaDataSize = 8 # see: FlashMLA/csrc/params.h (max_num_sm_parts, 8), dtype=torch.int32, device=device) self.num_splits_buffer = torch.zeros( # We pack all the tokens into one batch for sparse attention. # Otherwise, we can exceed the sm of `get_mla_metadata`. ( 2, ), dtype=torch.int32, device=device) self.req_id_per_token_buffer = torch.zeros( (vllm_config.scheduler_config.max_num_batched_tokens, ), dtype=torch.int32, device=device) def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> FlashMLASparseMetadata: num_tokens = common_attn_metadata.num_actual_tokens starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) seg_lengths = np.diff(starts) req_id_per_token = np.repeat( np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] fp8_extra_metadata = None if self.use_fp8_kv_cache: cache_seqlens_cpu, cache_seqlens = get_mla_metadata( cache_seqlens=self.topk_tokens_tensor, ) fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( scheduler_metadata=None, num_splits=None, # cache_lens and block_table are basically unused in sparse case # but the decode kernel will treat -1 and indices >= cache_lens # as invalid so we make sure cache_lens is large enough to not # accidentally mark indices invalid, we will use -1 exclusively # to mark invalid indices cache_lens=cache_seqlens_cpu, dummy_block_table=self.dummy_block_table) (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold or 1, require_uniform=True, ) ) # For pure decode batches, prefill_request_id will be None # For mixed batches, it will have -1 for decode and request_id for prefill prefill_metadata = None if num_prefills > 0: prefill_metadata = MLASparsePrefillMetadata( query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值 query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes], ) decode_metadata = None if num_decodes > 0: max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_decodes].max()) decode_metadata = FlashMLASparseDecodeAndContextMetadata( max_seq_len=max_seq_len, seq_lens=common_attn_metadata.seq_lens[:num_decodes], seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes], ) metadata = FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, max_seq_len=common_attn_metadata.max_seq_len, num_actual_tokens=common_attn_metadata.num_actual_tokens, query_start_loc=common_attn_metadata.query_start_loc, slot_mapping=common_attn_metadata.slot_mapping, block_table=common_attn_metadata.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, fp8_extra_metadata=fp8_extra_metadata, num_prefills=num_prefills, num_decodes=num_decodes, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, decode_metadata=decode_metadata, prefill_metadata=prefill_metadata ) return metadata class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments topk_indice_buffer: Optional[torch.Tensor] = None, indexer: Optional["Indexer"] = None, **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer self.padding = 128 if current_platform.is_device_capability( 100) else 64 def _forward_bf16_kv( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, topk_indices: torch.Tensor, attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view( -1, kv_c_and_k_pe_cache.shape[-1]) # num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens num_decodes = attn_metadata.num_decodes has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens def _bf16_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: # Reshape q: (num_decode_tokens, num_heads, head_dim) # -> (num_decodes, seq_len, num_heads, head_dim) q = reshape_query_for_spec_decode(q, num_decodes) seq_len = q.shape[1] # Reshape topk_indices: (num_decode_tokens, topk) # -> (num_decodes, seq_len, topk) topk_indices = topk_indices.view(num_decodes, seq_len, -1) decode_metadata = attn_metadata.decode_metadata _attn_out, _, _ = kunlun_flash_mla_with_kvcache( q=q, k_cache=kv_c_and_k_pe_cache, head_dim_v=512, cache_seqlens=decode_metadata.seq_lens, cache_seqlens_cpu=decode_metadata.seq_lens_cpu, is_fp8_kvcache=False, indices=topk_indices, softmax_scale=self.softmax_scale, max_seq_kv=decode_metadata.max_seq_len ) # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v) # -> (num_decode_tokens, num_heads, head_dim_v) return reshape_attn_output_for_spec_decode(_attn_out) def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: prefill_metadata = attn_metadata.prefill_metadata topk_indices = topk_indices.view(num_prefill_tokens, 1, -1) # NOTE: 只有prefill阶段attn_metadata.query_start_loc是符合klx算子需求的 _attn_out = flash_mla_sparse_prefill( q=q, kv=kv_c_and_k_pe_cache, indices=topk_indices, sm_scale=self.softmax_scale, q_lod_xpu=prefill_metadata.query_start_loc, q_lod_cpu=prefill_metadata.query_start_loc_cpu )[0] return _attn_out topk_indices_global = torch.ops.xspeedgate_ops.convert_req_index_to_global_index( req_id=attn_metadata.req_id_per_token, block_table=attn_metadata.block_table, token_indices=topk_indices, block_size=attn_metadata.block_size, num_topk_tokens=attn_metadata.topk_tokens, ) attn_out = torch.empty( (num_tokens, self.num_heads, self.kv_lora_rank), dtype=q.dtype, device=q.device, ) if has_prefill: prefill_q = q[num_decode_tokens:] prefill_topk_indices_global = topk_indices_global[num_decode_tokens:] attn_out[num_decode_tokens:] = _bf16_prefill(prefill_q, prefill_topk_indices_global) # 处理decode部分 - 需要正确的block table映射print if has_decode: decode_q = q[:num_decode_tokens] decode_topk_indices_global = topk_indices_global[:num_decode_tokens] attn_out[:num_decode_tokens] = _bf16_decode(decode_q, decode_topk_indices_global) return attn_out def _forward_fp8_kv(self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, topk_indices: torch.Tensor, attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: # TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this function. assert attn_metadata.fp8_extra_metadata is not None extra_metadata = attn_metadata.fp8_extra_metadata _attn_out, _ = flash_mla_with_kvcache( q=q.unsqueeze(0), # unsqueeze to add batch_dim k_cache=kv_c_and_k_pe_cache, block_table=extra_metadata.dummy_block_table, head_dim_v=512, cache_seqlens=extra_metadata.cache_lens, tile_scheduler_metadata=extra_metadata.scheduler_metadata, # None num_splits=extra_metadata.num_splits, # None is_fp8_kvcache=True, indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim softmax_scale=self.softmax_scale, max_seq_kv=attn_metadata.max_seq_len ) return _attn_out def forward( self, layer: AttentionLayer, q: torch.Tensor, k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # MQA 576/512 approach for both prefill and decode assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLACommonImpl") if attn_metadata is None: # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) ql_nope = torch.bmm(q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) ql_nope = ql_nope.transpose(0, 1) topk_indices = self.topk_indices_buffer[:num_actual_toks] q = torch.cat([ql_nope, q_pe], dim=-1) # write the latent and rope to kv cache if kv_cache.numel() > 0: torch.ops._C.concat_and_cache_mla( kv_c=k_c_normed, k_pe=k_pe.squeeze(1), kv_cache=kv_cache, slot_mapping=attn_metadata.slot_mapping.flatten(), ) if self.kv_cache_dtype != "fp8_ds_mla": attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) else: # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, # attn_metadata) raise NotImplementedError self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output