Files
enginex-biren-vllm/vllm_br/v1/attention/backends/attention_v1.py

658 lines
27 KiB
Python
Raw Normal View History

2026-03-10 13:31:25 +08:00
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
import torch
import torch_br
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_br.config.compilation import SUPAGraphMode
if TYPE_CHECKING:
pass
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class SUPAFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the SUPA flash-attention implementation.

    Exposes the implementation / metadata / metadata-builder classes of this
    backend together with helpers describing the KV-cache shape and layout.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that ``head_size`` is supported.

        Raises:
            ValueError: if ``head_size`` is not one of
                ``get_supported_head_sizes()``.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        """Return the attention metadata class."""
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        """Return the attention metadata builder class."""
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The result is ``(2, num_blocks, block_size, num_kv_heads,
        head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the 4-D KV-cache shape used for the SUPA speed-up layout.

        ``th_gran`` blocks (see ``get_kv_cache_usharp_alignment``) are folded
        into one row group, and the head dimensions are flattened, giving
        ``(2, ceil(num_blocks / th_gran), th_gran * block_size,
        num_kv_heads * head_size)`` with at least one row group.

        Raises:
            ValueError: if ``block_size`` is not positive or is so large that
                the alignment granularity would be zero (previously this
                surfaced as an opaque ``ZeroDivisionError``).
        """
        if block_size <= 0:
            raise ValueError("Block size must be a positive integer.")
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        if th_gran < 1:
            raise ValueError(
                f"Block size {block_size} is too large for the usharp KV "
                "cache layout (alignment granularity would be 0).")
        # Ceil-divide so a partially filled last group still gets storage.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only rendered when DEBUG
        # logging is enabled.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [2, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit in the 2048 limit.

        The result is ``2048 // block_size``; it is 0 when ``block_size``
        exceeds 2048.
        """
        max_h_limit = 2048
        return max_h_limit // block_size

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from logical shape to memory layout.

        Raises:
            ValueError: if the configured cache layout is neither ``"NHD"``
                nor ``"HND"``.
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the matching torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not ``"fp8"`` or
                ``"fp8_e4m3"``.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class SUPAFlashAttentionMetadata:
    """Per-step attention metadata handed to the SUPA attention layers.

    Definition of context_len, query_len, and seq_len for one request:

        |---------- N-1 iteration --------|
        |---------------- N iteration ---------------------|
        |- tokenA -|......................|-- newTokens ---|
        |---------- context_len ----------|
        |-------------------- seq_len ---------------------|
                                          |-- query_len ---|
    """

    # ---- Common batch description ----
    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # ---- BIREN attention parameters ----
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    # When attention split is used, do_cache is False.
    do_cache: bool
    num_actual_reqs: torch.Tensor

    # ---- Graph mode ----
    supagraph_runtime_mode: SUPAGraphMode

    # ---- Prefill / decode split ----
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # ---- Cascade attention ----
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # ---- Optional AOT scheduling ----
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    causal: bool = True

    # NOTE: the nested LocalAttentionMetadata dataclass and the
    # `local_attn_metadata` field (local attention) are currently disabled
    # and therefore not part of this metadata.
class SUPAFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[SUPAFlashAttentionMetadata]):
    """Builds ``SUPAFlashAttentionMetadata`` from ``CommonAttentionMetadata``.

    AOT scheduling is switched off for this backend (``aot_schedule`` is
    always ``False``), so every ``scheduler_metadata`` value produced here is
    ``None``.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.ALWAYS
    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the configuration values needed when building metadata.

        Args:
            kv_cache_spec: KV-cache specification; supplies dtype and block
                size.
            layer_names: names of the attention layers served by the builder.
            vllm_config: global vLLM configuration.
            device: device on which helper tensors are created.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling (FlashAttention 3 only upstream) is not supported by
        # SUPA attention, so it is disabled unconditionally.
        self.aot_schedule = False

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SUPAFlashAttentionMetadata:
        """Build the attention metadata for one scheduling step.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value greater than 0 enables the cascade fields.
            common_attn_metadata: backend-independent batch description.
            fast_build: disables AOT scheduling, used when there will be few
                iterations i.e. spec-decode.

        Returns:
            The populated ``SUPAFlashAttentionMetadata``.

        Raises:
            NotImplementedError: if AOT scheduling is (ever) enabled, since
                SUPA attention does not support it.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)

        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal
        num_actual_reqs = common_attn_metadata.num_actual_reqs
        seq_start_loc = common_attn_metadata.seq_start_loc
        context_lens = common_attn_metadata.context_lens

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
                num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # The arguments are accepted but unused: there is no AOT
            # scheduler for SUPA, so the only valid result is None. The
            # effective (fast_build-aware) flag is checked.
            if aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        # NOTE: local (chunked) attention metadata is currently not built.

        use_cascade = common_prefix_len > 0
        if use_cascade:
            seq_lens_cpu = common_attn_metadata.seq_lens_cpu
            # Use the builder's own device; this class has no `runner`
            # attribute.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] -
                              common_prefix_len).to(self.device,
                                                    non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # Fall back to the longest sequence in the batch when the caller did
        # not provide a decode-specific maximum.
        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        attn_metadata = SUPAFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs,
            supagraph_runtime_mode=common_attn_metadata.supagraph_runtime_mode)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Report that cascade attention is never used by this backend."""
        return False
class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention layer implementation dispatching to the torch_br kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer hyper-parameters and validate the configuration.

        Raises:
            ValueError: for an unsupported head size, or for ``sinks`` with a
                wrong leading dimension or a dtype other than float32.
            NotImplementedError: for attention types other than DECODER and
                ENCODER_ONLY, or for a quantized KV cache when FP8 is not
                supported.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        # ALiBi slopes, when given, are kept as a float32 CPU tensor.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32, device="cpu"))

        # Any falsy window (None or 0) is normalised to None.
        self.sliding_window = sliding_window or None
        self.kv_cache_dtype = kv_cache_dtype

        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        SUPAFlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        supported_attn_types = (AttentionType.DECODER,
                                AttentionType.ENCODER_ONLY)
        if attn_type not in supported_attn_types:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        self.vllm_flash_attn_version = get_flash_attn_version()
        quantized_cache = is_quantized_kv_cache(self.kv_cache_dtype)
        if quantized_cache and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks: Optional[torch.Tensor] = None
        if sinks is None:
            return
        if sinks.shape[0] != num_heads:
            raise ValueError(
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Expected {num_heads}, but got "
                f"{sinks.shape[0]}.")
        if sinks.dtype != torch.float32:
            raise ValueError("Sinks must be of type float32, but got "
                             f"{sinks.dtype}.")
        self.sinks = sinks

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is None, "Output tensor should not provided."

        if attn_metadata is None:
            # FIXME: this may lead to wrong block estimation
            # Profiling run.
            return query

        encoder_types = (AttentionType.ENCODER_ONLY, AttentionType.ENCODER)
        is_encoder = self.attn_type in encoder_types

        # NOTE: supa attn use [batch_size, num_tokens, num_heads * head_size]
        # as shape
        should_store = (kv_cache is not None and attn_metadata.do_cache
                        and not is_encoder)
        if should_store:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        if self.sinks is not None:
            return self.forward_sw_sinks(query, kv_cache, attn_metadata)

        if is_encoder:
            # Encoder attention runs directly on the given q/k/v.
            assert len(query.shape) == 3
            return torch_br.supa_flash_attention_infer(  # type: ignore
                query,
                key,
                value,
                attn_metadata.query_start_loc,
                self.head_size,
                len(attn_metadata.query_start_loc),  # type: ignore
                self.alibi_slopes,
                softmax_scale=self.scale,
                is_causal=_get_causal_option(self.attn_type))

        graph_mode = attn_metadata.supagraph_runtime_mode
        decode_kernel_allowed = graph_mode is None or graph_mode in (
            SUPAGraphMode.NONE, SUPAGraphMode.FULL_DECODE_ONLY)
        has_prefill = attn_metadata.num_prefill_tokens > 0

        # The decode kernel is used only for decode-only batches
        # (non-mtp) in the graph modes listed above; everything else goes
        # through the paged prefill kernel.
        if decode_kernel_allowed and not has_prefill:
            return torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                self.alibi_slopes,
                softmax_scale=self.scale)

        return self._prefill_with_kvcache(query, kv_cache, attn_metadata)

    def _prefill_with_kvcache(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run the KV-cache based flash-attention kernel for the batch."""
        return torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            self.head_size,
            alibi_slopes=self.alibi_slopes,
            softmax_scale=self.scale,
            num_reqs=attn_metadata.num_actual_reqs)

    # sliding window with sinks impl
    def forward_sw_sinks(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Attention path used when sinks are configured.

        Calls the cache-based kernel with this layer's sliding window and
        sink tensor.
        """
        # prefix-enabled attention
        return torch_br.supa_flash_attn_cache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            window_size=self.sliding_window,
            sinks=self.sinks)
def _get_causal_option(attn_type: str) -> bool:
    """
    Tell whether causal masking applies to the given attention type.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention; `True` for every other attention type.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types