################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple

import numpy as np
import torch
import torch_br

from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata
from vllm.attention.ops.flashmla import get_mla_metadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    FlashMLASparseBackend, FlashMLASparseImpl, FlashMLASparseMetadata,
    FlashMLASparseMetadataBuilder)
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              CommonAttentionMetadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

logger = init_logger(__name__)

_NO_DEFAULT = object()


@dataclass
class SupaFlashMLASparseMetadata(FlashMLASparseMetadata):
    # BIREN attention params
    seq_start_loc: torch.Tensor = _NO_DEFAULT
    context_lens: torch.Tensor = _NO_DEFAULT
    max_decode_seq_len: int = -1
    num_prefills: int = -1
    num_decodes: int = -1
    num_prefill_tokens: int = -1
    num_decode_tokens: int = -1

    def __post_init__(self):
        if (self.seq_start_loc is _NO_DEFAULT
                or self.context_lens is _NO_DEFAULT
                or self.max_decode_seq_len == -1 or self.num_prefills == -1
                or self.num_decodes == -1 or self.num_prefill_tokens == -1
                or self.num_decode_tokens == -1):
            raise TypeError("__init__ missing required argument")
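

# Relationship between the extra BIREN fields above and the common attention
# metadata, as computed in SupaFlashMLASparseMetadataBuilder.build() below.
# Illustrative numbers only (not taken from a real schedule): for
# seq_lens = [5, 3] and per-request scheduled query lengths [2, 3],
#     seq_start_loc      = [0, 5, 8]   (prefix sum of seq_lens)
#     context_lens       = [3, 0]      (seq_lens minus scheduled query lens)
#     max_decode_seq_len = 5           (max(seq_lens) when not provided)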


class SupaFlashMLASparseMetadataBuilder(FlashMLASparseMetadataBuilder):
    reorder_batch_threshold: int = 1
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(
            kv_cache_spec=kv_cache_spec,
            layer_names=layer_names,
            vllm_config=vllm_config,
            device=device,
        )
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config else 0)

        # For now, DeepGEMM's fp8_paged_mqa_logits does not support
        # next_n > 2.
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """On SUPA, we want prefills at the front and decodes at the back."""
        # We reorder the batch so that the "prefill" requests end up at the
        # front and the "decode" requests at the back, using the least number
        # of swaps possible. (NOTE: for now we loosely use "decode" to mean
        # requests where attention is likely memory-bound and "prefill" to
        # mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_spec_tokens = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            # For now, treat 1 scheduled token as "decode" even if it's not.
            # We should update this to something like < 8 in the future, but
            # currently TritonMLA._forward_decode only supports
            # num_tokens = 1.
            if num_tokens - num_spec_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that the number of swaps is fairly minimal since decodes
        # should be around for a number of iterations, so hopefully they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch, so they already should be at the back).
        # To achieve this we walk the prefills in descending order and the
        # decodes in ascending order, swapping prefills from the "back"
        # (i.e. past where the last prefill should be in the reordered batch)
        # with decodes from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        # Upstream variant (decodes at the front, prefills at the back),
        # kept for reference:
        # for i in range(1, min(num_decodes, num_prefills) + 1):
        #     # If the decode is at the "back" of the batch, i, we can swap it
        #     # with the prefill closest to the front of the batch
        #     decode_idx = decodes[num_decodes - i]
        #     if decode_idx < num_decodes:
        #         break
        #     input_batch.swap_states(prefills[i - 1], decode_idx)
        #     modified_batch = True

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the prefill is at the "back" of the batch, we can swap it
            # with the decode closest to the front of the batch.
            prefills_idx = prefills[num_prefills - i]
            if prefills_idx < num_prefills:
                break
            input_batch.swap_states(decodes[i - 1], prefills_idx)
            modified_batch = True

        # Save for the next `build` call.
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this.
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch
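
    # Illustrative sketch of the reordering above (hypothetical batch, not
    # produced by the scheduler): for requests [P0, D1, P2, D3]
    # (P = prefill, D = decode), prefills = [0, 2] and decodes = [1, 3].
    # The loop takes the last prefill P2, sees that index 2 is not within the
    # first num_prefills = 2 slots, and swaps it with the first decode D1,
    # giving [P0, P2, D1, D3]: prefills at the front, decodes at the back,
    # with a single swap.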

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SupaFlashMLASparseMetadata:
        num_tokens = common_attn_metadata.num_actual_tokens

        starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                            dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
            .copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata = None
        if self.use_fp8_kv_cache:
            tile_scheduler_metadata, num_splits = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor,
                num_q_tokens_per_head_k=num_tokens * self.num_heads,
                topk=self.topk_tokens,
                num_heads_q=self.num_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            num_sm_parts = tile_scheduler_metadata.size(0)
            # Copy to persistent buffer for full-CG support
            tile_scheduler_metadata_buffer = \
                self.tile_scheduler_metadata_buffer[:num_sm_parts]
            tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
            self.num_splits_buffer.copy_(num_splits)

            fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=tile_scheduler_metadata_buffer,
                num_splits=self.num_splits_buffer,
                # cache_lens and block_table are basically unused in the
                # sparse case, but the decode kernel will treat -1 and
                # indices >= cache_lens as invalid, so we make sure
                # cache_lens is large enough to not accidentally mark
                # indices invalid; we will use -1 exclusively to mark
                # invalid indices.
                cache_lens=self.max_model_len_tensor,
                dummy_block_table=self.dummy_block_table)

        # Add BIREN attention params
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        if common_attn_metadata.seq_start_loc is None:
            # For larger batches, accumulate on a CPU copy to avoid
            # per-element device reads.
            if len(seq_lens) > 8:
                seq_lens_cpu = seq_lens.cpu()
                seq_start_loc = torch.tensor(
                    [0] + list(itertools.accumulate(seq_lens_cpu)),
                    device=query_start_loc.device,
                    dtype=torch.int32)
            else:
                seq_start_loc = torch.tensor(
                    [0] + list(itertools.accumulate(seq_lens)),
                    device=query_start_loc.device,
                    dtype=torch.int32)
        else:
            seq_start_loc = common_attn_metadata.seq_start_loc

        if common_attn_metadata.context_lens is None:
            context_lens = seq_lens - (query_start_loc[1:] -
                                       query_start_loc[:-1])
        else:
            context_lens = common_attn_metadata.context_lens

        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        metadata = SupaFlashMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
        )
        return metadata


class SupaFlashMLASparseBackend(FlashMLASparseBackend):

    @staticmethod
    def get_name() -> str:
        return "SUPA_FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        return SupaFlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLASparseMetadataBuilder"]:
        return SupaFlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLASparseImpl"]:
        return SupaFlashMLASparseImpl

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        th_gran = SupaFlashMLASparseBackend.get_kv_cache_usharp_alignment(
            block_size)
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            f'Original kv cache shape is '
            f'[2, {num_blocks}, {block_size}, {num_kv_heads}, {head_size}]; '
            f'for SUPA speed-up, use '
            f'[2, {n_block}, {th_gran * block_size}, '
            f'{num_kv_heads * head_size}]'  # noqa: G004
        )
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        max_h_limit = 2048
        return max_h_limit // block_size
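

# Shape arithmetic behind get_kv_cache_usharp_shape, worked on hypothetical
# numbers (not from any particular model config): with block_size = 64 the
# alignment is th_gran = 2048 // 64 = 32, so num_blocks = 1000 folds into
# n_block = ceil(1000 / 32) = 32 and the cache is laid out as
# (2, 32, 32 * 64, num_kv_heads * head_size) = (2, 32, 2048, ...).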


class SupaFlashMLASparseImpl(FlashMLASparseImpl):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, topk_indice_buffer,
                         indexer, **mla_args)

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: SupaFlashMLASparseMetadata) -> torch.Tensor:
        bsz = 1
        seq_len_q, num_heads, _ = q.shape

        # topk_indices = topk_indices.unsqueeze(0)
        # Build a dense mask from the top-k indices: positions selected by
        # topk_indices are set to 0, everything else stays 1.
        index_mask = torch.full((bsz, seq_len_q, seq_len_q),
                                1,
                                dtype=torch.int32,
                                device=q.device)
        # .scatter_(-1, valid_mask.to(torch.int64), 0).to(torch.int32).supa()
        for idx_bsz in range(bsz):
            for idx_q in range(seq_len_q):
                for idx_k in range(topk_indices.shape[-1]):
                    target_idx = topk_indices[idx_q][idx_k]
                    if 0 <= target_idx < seq_len_q:
                        index_mask[idx_bsz][idx_q][target_idx] = 0

        # [num_heads, seq_len, head_dim]
        query = q.transpose(0, 1).contiguous()
        # output is always [1, seq_len, num_heads * head_dim], even though
        # query's shape is [num_heads, seq_len, head_dim].
        output = torch_br.supa_flash_attn_cache_infer(
            query,
            # [1, num_blocks, block_size, self.head_size]
            kv_c_and_k_pe_cache[:1],
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            softmax_scale=self.softmax_scale,
            v_head_size=self.kv_lora_rank,
            mask=index_mask)
        output = output.reshape(seq_len_q, num_heads,
                                self.kv_lora_rank).contiguous()
        return output
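
    # How the mask built in _forward_bf16_kv is meant to be consumed (a
    # sketch, assuming supa_flash_attn_cache_infer treats 0 as "attend" and
    # 1 as "masked"): for a toy topk_indices row [3, 0, -1] at some query
    # position, only key positions 3 and 0 are unmasked for that query; the
    # -1 entry is skipped because only indices in [0, seq_len_q) clear the
    # mask.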

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(lucas): the sparse FlashMLA kernels want to use the MQA 576/512
        # approach for both prefill and decode.
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)

        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # TODO: handle index / kv_cache correctly
        # topk_indices_global = triton_convert_req_index_to_global_index(
        #     attn_metadata.req_id_per_token,
        #     attn_metadata.block_table,
        #     topk_indices,
        #     BLOCK_SIZE=attn_metadata.block_size,
        #     NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        # )

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # Write the latent and rope to the kv cache
        if kv_cache.numel() > 0:
            _, num_blocks, block_size, head_size = kv_cache.shape
            k_pe_tmp = k_pe.squeeze(1).unsqueeze(0)
            key_supa = torch.cat([k_c_normed, k_pe_tmp], dim=2)
            torch_br.supa_kvcache_store_infer_v2(kv_cache, key_supa, key_supa,
                                                 attn_metadata.slot_mapping,
                                                 head_size)

        if self.kv_cache_dtype != "fp8_ds_mla":
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                             attn_metadata)
        else:
            raise RuntimeError("fp8 KV cache is not supported on BR.")

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output