enginex-biren-vllm/vllm_br/v1/attention/backends/mla/flashmla_sparse.py
################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
import numpy as np
import torch
import torch_br
from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata
from vllm.attention.ops.flashmla import get_mla_metadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseImpl, FlashMLASparseMetadata,
FlashMLASparseMetadataBuilder)
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
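# `_NO_DEFAULT` is a sentinel standing in for "no default": dataclass fields
# added by a subclass must declare defaults once the base class has any
# defaulted fields, so each Biren field below gets a dummy default here and
# __post_init__ rejects it if the caller never supplied a real value.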
_NO_DEFAULT = object()
@dataclass
class SupaFlashMLASparseMetadata(FlashMLASparseMetadata):
# BIREN Attention Params
seq_start_loc: torch.Tensor = _NO_DEFAULT
context_lens: torch.Tensor = _NO_DEFAULT
max_decode_seq_len: int = -1
num_prefills: int = -1
num_decodes: int = -1
num_prefill_tokens: int = -1
num_decode_tokens: int = -1
    def __post_init__(self):
        # Reject the sentinel defaults: every Biren field must be supplied.
        if self.seq_start_loc is _NO_DEFAULT or self.context_lens is _NO_DEFAULT or \
                self.max_decode_seq_len == -1 or self.num_prefills == -1 or \
                self.num_decodes == -1 or self.num_prefill_tokens == -1 or \
                self.num_decode_tokens == -1:
            raise TypeError(
                "SupaFlashMLASparseMetadata.__init__ missing one or more "
                "required Biren attention arguments")
class SupaFlashMLASparseMetadataBuilder(FlashMLASparseMetadataBuilder):
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(
kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device,
)
self.vllm_config = vllm_config
self.num_speculative_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config else 0)
        # deepgemm's fp8_paged_mqa_logits currently does not support
        # next_n > 2, so cap the speculative contribution at 1.
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
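        # e.g. with speculative decoding enabled the threshold becomes
        # 1 + min(num_speculative_tokens, 1) = 2, so a request with up to
        # two scheduled tokens is still routed down the decode path.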
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
"""On SUPA, we want prefill at front and decode at back.
"""
# TODO update doc
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            # For now treat 1 scheduled token as "decode" even if it's not;
            # we should update this to something like < 8 in the future, but
            # currently TritonMLA._forward_decode only supports
            # num_tokens = 1.
if num_tokens - num_spec_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
        # We hope the number of swaps is fairly minimal since decodes tend
        # to stay in the batch for a number of iterations, so they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch, so they already start at the back).
        # To achieve this we loop over the prefills in descending order and
        # the decodes in ascending order. We swap prefills from the "back",
        # i.e. past where the last prefill should be in the reordered batch,
        # with decodes from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop.
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
        # Original upstream loop (decodes at the front), kept for reference:
        # for i in range(1, min(num_decodes, num_prefills) + 1):
        #     # If the decode is at the "back" of the batch, i, we can swap it
        #     # with the prefill closest to the front of the batch
        #     decode_idx = decodes[num_decodes - i]
        #     if decode_idx < num_decodes:
        #         break
        #     input_batch.swap_states(prefills[i - 1], decode_idx)
        #     modified_batch = True
        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If this prefill is already within the first `num_prefills`
            # slots it is in place (and so is every earlier prefill), so we
            # can stop; otherwise swap it with the decode closest to the
            # front of the batch.
            prefill_idx = prefills[num_prefills - i]
            if prefill_idx < num_prefills:
                break
            input_batch.swap_states(decodes[i - 1], prefill_idx)
            modified_batch = True
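        # Worked example: for a batch ordered [D, P, D, P] we get
        # decodes=[0, 2] and prefills=[1, 3]. Iteration 1 swaps decode 0
        # with prefill 3, giving [P, P, D, D]; iteration 2 sees
        # prefills[0] = 1 < num_prefills = 2 and stops. One swap total.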
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> SupaFlashMLASparseMetadata:
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
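        # e.g. query_start_loc_cpu = [0, 3, 5] -> seg_lengths = [3, 2]
        # -> req_id_per_token = [0, 0, 0, 1, 1]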
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
.copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata = None
if self.use_fp8_kv_cache:
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor,
num_q_tokens_per_head_k=num_tokens * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
num_sm_parts = tile_scheduler_metadata.size(0)
# Copy to persistent buffer for full-CG support
tile_scheduler_metadata_buffer = \
self.tile_scheduler_metadata_buffer[:num_sm_parts]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
self.num_splits_buffer.copy_(num_splits)
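            # Both copies above land in preallocated buffers because CUDA
            # graphs replay fixed device addresses; freshly allocated
            # tensors would not be visible to the captured graph.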
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=self.num_splits_buffer,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens=self.max_model_len_tensor,
dummy_block_table=self.dummy_block_table)
# Add biren attention params
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
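        # Iterating a CUDA tensor element-by-element synchronizes with the
        # device on every access, so for more than a handful of requests it
        # is cheaper to copy seq_lens to the CPU once before accumulating.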
if common_attn_metadata.seq_start_loc is None:
if len(seq_lens) > 8:
seq_lens_cpu = seq_lens.cpu()
seq_start_loc = torch.tensor(
[0] + list(itertools.accumulate(seq_lens_cpu)),
device=query_start_loc.device,
dtype=torch.int32)
else:
seq_start_loc = torch.tensor(
[0] + list(itertools.accumulate(seq_lens)),
device=query_start_loc.device,
dtype=torch.int32)
else:
seq_start_loc = common_attn_metadata.seq_start_loc
if common_attn_metadata.context_lens is None:
context_lens = seq_lens - (query_start_loc[1:] -
query_start_loc[:-1])
else:
context_lens = common_attn_metadata.context_lens
        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
else:
max_decode_seq_len = common_attn_metadata.max_decode_seq_len
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
metadata = SupaFlashMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
seq_start_loc=seq_start_loc,
context_lens=context_lens,
max_decode_seq_len=max_decode_seq_len,
num_prefills=num_prefills,
num_decodes=num_decodes,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
)
return metadata
class SupaFlashMLASparseBackend(FlashMLASparseBackend):
@staticmethod
def get_name() -> str:
return "SUPA_FLASHMLA_SPARSE_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
return SupaFlashMLASparseMetadata
@staticmethod
def get_builder_cls() -> type["SupaFlashMLASparseMetadataBuilder"]:
return SupaFlashMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["SupaFlashMLASparseImpl"]:
return SupaFlashMLASparseImpl
@staticmethod
def get_kv_cache_usharp_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
th_gran = SupaFlashMLASparseBackend.get_kv_cache_usharp_alignment(
block_size)
n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            'Original kv cache shape is [2, %d, %d, %d, %d]; for SUPA '
            'speed-up, use [2, %d, %d, %d]', num_blocks, block_size,
            num_kv_heads, head_size, n_block, th_gran * block_size,
            num_kv_heads * head_size)
return (2, n_block, th_gran * block_size, num_kv_heads * head_size)
@staticmethod
def get_kv_cache_usharp_alignment(block_size: int) -> int:
max_h_limit = 2048
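        # e.g. block_size=16 -> alignment of 2048 // 16 = 128 blocks, so the
        # fused layout packs 128 * 16 = 2048 tokens per super-block.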
return max_h_limit // block_size
class SupaFlashMLASparseImpl(FlashMLASparseImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
topk_indice_buffer: Optional[torch.Tensor] = None,
indexer: Optional["Indexer"] = None,
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, topk_indice_buffer,
indexer, **mla_args)
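    # _forward_bf16_kv builds a dense [bsz, seq_len_q, seq_len_q] mask from
    # the sparse top-k indices (0 = attend, 1 = masked) and runs Biren's
    # varlen flash-attention cache kernel over the bf16 KV cache.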
    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: SupaFlashMLASparseMetadata) -> torch.Tensor:
        bsz = 1
        seq_len_q, num_heads, _ = q.shape
        # Dense mask derived from the sparse top-k indices:
        # 1 = masked out, 0 = attended.
        index_mask = torch.full((bsz, seq_len_q, seq_len_q),
                                1,
                                dtype=torch.int32,
                                device=q.device)
        for idx_bsz in range(bsz):
            for idx_q in range(seq_len_q):
                for idx_k in range(topk_indices.shape[-1]):
                    target_idx = topk_indices[idx_q][idx_k]
                    if target_idx >= 0 and target_idx < seq_len_q:
                        index_mask[idx_bsz][idx_q][target_idx] = 0
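        # e.g. topk_indices[q] = [2, 5, -1] clears columns 2 and 5 of row q;
        # the -1 padding entry is skipped, so those columns stay masked.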
query = q.transpose(0,
1).contiguous() # [num_heads, seq_len, head_dim]
        # output is always [1, seq_len, num_heads * head_dim] even though
        # the query is laid out as [num_heads, seq_len, head_dim]
output = torch_br.supa_flash_attn_cache_infer(
query,
            kv_c_and_k_pe_cache[:1],  # [1, num_blocks, block_size, head_size]
attn_metadata.query_start_loc,
attn_metadata.seq_start_loc,
attn_metadata.block_table,
attn_metadata.context_lens,
attn_metadata.slot_mapping,
attn_metadata.max_seq_len,
self.head_size,
softmax_scale=self.softmax_scale,
v_head_size=self.kv_lora_rank,
mask=index_mask)
output = output.reshape(seq_len_q, num_heads,
self.kv_lora_rank).contiguous()
return output
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: SupaFlashMLASparseMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        # NOTE(lucas): the sparse FlashMLA kernels use the MQA 576/512
        # approach for both prefill and decode
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl")
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
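        # Weight absorption: folding W_UK into the query lets attention run
        # as MQA in the kv_lora_rank latent space; for DeepSeek-style MLA
        # this is the 576/512 layout noted above (512-dim latent plus
        # 64-dim rope).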
topk_indices = self.topk_indices_buffer[:num_actual_toks]
# TODO: handle index / kv_cache correctly
# topk_indices_global = triton_convert_req_index_to_global_index(
# attn_metadata.req_id_per_token,
# attn_metadata.block_table,
# topk_indices,
# BLOCK_SIZE=attn_metadata.block_size,
# NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
# )
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
_, num_blocks, block_size, head_size = kv_cache.shape
k_pe_tmp = k_pe.squeeze(1).unsqueeze(0)
key_supa = torch.cat([k_c_normed, k_pe_tmp], dim=2)
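            # key_supa packs the compressed latent and the rope key into a
            # single cache row, matching the 576 = 512 + 64 layout above.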
torch_br.supa_kvcache_store_infer_v2(kv_cache, key_supa, key_supa,
attn_metadata.slot_mapping,
head_size)
if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
attn_metadata)
else:
raise RuntimeError("Not support fp8 on br.")
self._v_up_proj(attn_out, out=output[:num_actual_toks])
return output