Files
enginex-biren-vllm/vllm_br/v0/attention/backends/attention_v0.py
2026-03-10 13:31:25 +08:00

571 lines
23 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""vLLM V0 attention backend built on Biren SUPA flash-attention kernels (torch_br)."""
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
import torch_br
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import logger
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder)
from collections import defaultdict
from itertools import accumulate
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
class SUPAFlashAttentionBackend(AttentionBackend):
    """vLLM V0 attention backend that dispatches to Biren SUPA kernels.

    Exposes the hooks vLLM queries on a backend (name, impl / metadata /
    builder / state classes, supported head sizes) plus the KV-cache shapes
    used on SUPA devices.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes accepted by SUPAFlashAttentionImpl."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V0"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The shape is ``(2, num_blocks, block_size, num_kv_heads, head_size)``;
        the leading 2 presumably separates keys from values.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape regrouped for SUPA kernels.

        Groups of ``th_gran`` consecutive cache blocks are merged into one row
        of ``th_gran * block_size`` tokens, and the head dimensions are
        flattened into the last axis. The block count is rounded up to whole
        groups, with a minimum of one group.

        Returns:
            ``(2, n_block, th_gran * block_size, num_kv_heads * head_size)``.
        """
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only formatted when DEBUG
        # logging is actually enabled.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s], For SUPA Speed up, "
            "use [2, %s, %s, %s]", num_blocks, block_size, num_kv_heads,
            head_size, n_block, th_gran * block_size, num_kv_heads * head_size)
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many cache blocks fit within a 2048-row group.

        NOTE(review): returns 0 when ``block_size > 2048``, which would make
        ``get_kv_cache_usharp_shape`` divide by zero; confirm such block
        sizes are never configured.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
@dataclass
class SUPAFlashAttentionMetadata:
    """Per-step attention metadata consumed by SUPAFlashAttentionImpl.

    Instances are produced by ``SUPAFlashAttentionMetadataBuilder.build``.
    Length terminology used by the fields (NOTE(sang)):

        |---------- N-1 iteration --------|
        |---------------- N iteration ---------------------|
        |- tokenA -|......................|-- newTokens ---|
        |---------- context_len ----------|
        |-------------------- seq_len ---------------------|
                                          |-- query_len ---|
    """

    # ---- generic flash-attention fields ----
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Cumulative query offsets starting at 0 (one entry per sequence + 1).
    query_start_loc: torch.Tensor
    # The builder fills this with the longest *prefill* seq_len.
    max_seq_len: int
    # NOTE(review): the builder passes a plain List[int] here, not a tensor.
    seq_lens: torch.Tensor
    seq_lens_tensor: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # ---- BIREN attention params ----
    # NOTE(review): the builder passes a plain List[int] here, not a tensor.
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    num_prefills: int
    # Despite the name, the builder stores the number of decode *tokens*.
    num_decodes: int
    num_prefills_tokens: int
    # False when attention-split is used; forward() then skips its explicit
    # KV-cache store.
    do_cache: bool

    # ---- cascade attention ----
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # ---- optional AOT scheduling ----
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None

    # Cached split views; prefill_metadata only reads the prefill one.
    _cached_prefill_metadata: Optional["SUPAFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["SUPAFlashAttentionMetadata"] = None

    @dataclass
    class LocalAttentionMetadata:
        """Metadata for local attention (see ``local_attn_metadata``)."""
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    # Local-attention metadata; None when local attention is not used.
    local_attn_metadata: Optional[LocalAttentionMetadata] = None

    @property
    def do_prefill(self) -> bool:
        """True when the batch contains at least one prefill sequence."""
        return 0 < self.num_prefills

    @property
    def do_decode(self) -> bool:
        """True when the batch contains decode work (``num_decodes > 0``)."""
        return 0 < self.num_decodes

    @property
    def prefill_metadata(self) -> Optional["SUPAFlashAttentionMetadata"]:
        """Return the cached prefill-only view, or None.

        None is returned when there are no prefills, or when no prefill view
        has been cached; this property never builds one itself.
        """
        if self.num_prefills != 0:
            return self._cached_prefill_metadata
        return None
class SUPAFlashAttentionMetadataBuilder:
    """Accumulate per-sequence-group data and build SUPAFlashAttentionMetadata.

    ``prepare()`` resets the per-step accumulators. ``build()`` feeds every
    entry of ``input_builder.inter_data_list`` through ``_add_seq_group``,
    converts the collected lists into device tensors and returns the
    metadata object.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset the per-step accumulators before a new batch is built."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Append one sequence group's data to the accumulators.

        For every sequence in ``inter_data`` this records:
        1. its context length;
        2. prefill vs. decode counters (plus multi-modal placeholder maps
           for prompts);
        3. its block table: the full table on a prefix-cache hit; the full
           or sliding-window-trimmed table for chunked-prefill / decode
           steps; an empty list otherwise;
        4. its slot mapping, via ``compute_slot_mapping``.

        Args:
            inter_data: Per-sequence-group data from the input builder.
            chunked_prefill_enabled: Whether chunked prefill is enabled.
            prefix_cache_hit: Batch-wide flag, True if any group in the batch
                hit the prefix cache.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens,
                 inter_data.seq_lens,
                 inter_data.query_lens,
                 inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks,
                 strict=False):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the blocks inside the sliding window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        """Copy ``block_tables`` into the runner's preallocated graph table.

        Each non-empty table is written into row ``i`` of
        ``self.runner.graph_block_tables[:num_seqs]`` (truncated to the
        table width). Empty tables leave their row untouched. The slice is
        then moved to the runner's device.
        """
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A SUPAFlashAttentionMetadata describing the accumulated batch
            (``do_cache`` is always False).
        """
        # Batch-wide flag: a hit in any group makes _add_seq_group use the
        # full block table for every sequence.
        prefix_cache_hit = any(
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list)
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padded sequence (``[] * n`` would be
            # ``[]`` and pad nothing). Empty tables are skipped by
            # _get_graph_runner_block_tables, so the tensor is unaffected.
            self.block_tables.extend(
                [[] for _ in range(cuda_graph_pad_size)])
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)

        return SUPAFlashAttentionMetadata(
            num_actual_tokens=batch_size,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc_tensor,
            max_seq_len=max_prefill_seq_len,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            block_table=block_tables,
            slot_mapping=slot_mapping_tensor,
            use_cascade=False,
            common_prefix_len=0,
            # NOTE(review): annotated Optional[torch.Tensor]; 0 is kept as-is
            # in case downstream code relies on it being "not None".
            scheduler_metadata=0,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
            local_attn_metadata=None,
            prefix_scheduler_metadata=None,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=self.num_prefills,
            num_decodes=num_decode_tokens,
            num_prefills_tokens=self.num_prefill_tokens,
            do_cache=False)
class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by Biren SUPA (``torch_br``) kernels.

    Prefill tokens use ``supa_flash_attention_infer`` (or the eager SDPA
    kernel when ``USE_BR_SUEAGER_SDPA`` is set) when no paged KV cache is
    usable, and ``supa_flash_attn_cache_infer`` otherwise. Decode tokens use
    ``supa_attention_decoder_infer_v2``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and store per-layer attention params.

        Raises:
            ValueError: if block-sparse params are given, or ``head_size`` is
                not in ``SUPAFlashAttentionBackend.get_supported_head_sizes()``.
            NotImplementedError: for a quantized KV cache when
                ``flash_attn_supports_fp8()`` is False.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.attn_type = attn_type
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # NOTE(review): sliding_window and logits_soft_cap are stored in
        # flash-attn form but never passed to a SUPA kernel in forward();
        # confirm they are unsupported rather than silently ignored.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = SUPAFlashAttentionBackend.get_supported_head_sizes(
        )
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=1 to use another attention backend.")

        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run prefill and/or decode attention for one batch.

        Args:
            layer: The calling attention layer (not used here).
            query, key, value: Per the NOTE in the body, SUPA kernels use a
                ``[batch_size, num_tokens, num_heads * head_size]`` layout;
                prefill tokens come first and are split from decode tokens
                along dim 1.
            kv_cache: Paged KV cache, or None.
            attn_metadata: Batch metadata; None during the profiling run.
            output: Must be None (``accept_output_buffer`` is False).

        Returns:
            For a mixed batch, a tensor shaped like ``query`` holding the
            prefill outputs followed by the decode outputs; otherwise the
            prefill or decode kernel output directly. ``query`` is returned
            unchanged during the profiling run, and None if the batch has
            neither prefill nor decode work.
        """
        assert output is None, "Output tensor should not provided."
        if attn_metadata is None:
            # FIXME: this may lead to wrong block estimation
            # Profiling run.
            return query

        # NOTE: supa attn use [batch_size, num_tokens, num_heads * head_size] as shape
        if kv_cache is not None and attn_metadata.do_cache:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        do_prefill = attn_metadata.do_prefill
        do_decode = attn_metadata.do_decode
        num_prefill_tokens = attn_metadata.num_prefills_tokens
        output_prefill = output_decode = None

        if do_prefill and do_decode:
            # Mixed (chunked) batch: the stitched output buffer is only
            # needed here, so it is allocated only here, sized like the full
            # query, before the inputs are split along dim 1.
            output = torch.empty_like(query)
            decode_query = query[:, num_prefill_tokens:]
            query = query[:, :num_prefill_tokens]
            key = key[:, :num_prefill_tokens]
            value = value[:, :num_prefill_tokens]
        elif do_decode:
            decode_query = query

        if do_prefill:
            if (kv_cache is None or attn_metadata.block_table.numel() == 0):
                # has do_decode should go into prefix-enabled branch
                assert not do_decode
                # in this branch, query_start_loc = seq_start_loc
                if os.getenv('USE_BR_SUEAGER_SDPA',
                             'False').lower() not in {'false', '0', ''}:
                    # Eager SDPA kernel; its second result is not used.
                    output_prefill, _ = torch_br.sueager_scaled_dot_product_attention_fwd(
                        query=query,
                        key=key,
                        value=value,
                        mask=None,
                        dropout_prob=0.0,
                        is_causal=_get_causal_option(self.attn_type),
                        scale=self.scale,
                        algorithm="FMHA",
                    )
                    # NOTE(review): num_kv_heads is passed as the head count
                    # of the attention *output*, which presumably has
                    # num_heads heads; these differ for GQA/MQA models.
                    # Confirm against supa_shape_transform_qkv's contract.
                    output_prefill = torch_br.supa_shape_transform_qkv(
                        output_prefill, 1, query.shape[1], self.num_kv_heads,
                        self.head_size)
                else:
                    output_prefill = torch_br.supa_flash_attention_infer(  # type: ignore
                        query,
                        key,
                        value,
                        attn_metadata.query_start_loc,
                        self.head_size,
                        len(attn_metadata.query_start_loc),  # type: ignore
                        self.alibi_slopes,
                        softmax_scale=self.scale,
                        is_causal=_get_causal_option(self.attn_type))
            else:
                # prefix-enabled attention over the paged KV cache
                output_prefill = torch_br.supa_flash_attn_cache_infer(  # type: ignore
                    query,
                    kv_cache,
                    attn_metadata.query_start_loc,
                    attn_metadata.seq_start_loc,
                    attn_metadata.block_table,
                    attn_metadata.context_lens,
                    attn_metadata.slot_mapping,
                    attn_metadata.max_seq_len,
                    self.head_size,
                    self.alibi_slopes,
                    softmax_scale=self.scale)

        if do_decode:
            output_decode = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                decode_query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                self.alibi_slopes,
                softmax_scale=self.scale)

        if do_prefill and do_decode:
            output[:, :num_prefill_tokens] = output_prefill
            output[:, num_prefill_tokens:] = output_decode
            return output
        if do_prefill:
            return output_prefill
        # Decode-only batch (None if the batch had no work at all).
        return output_decode
def _get_causal_option(attn_type: str) -> bool:
    """Return whether ``attn_type`` should use causal attention.

    The encoder-style types (ENCODER, ENCODER_ONLY and ENCODER_DECODER) are
    treated as non-causal; every other attention type is causal.

    Args:
        attn_type: An ``AttentionType`` value.

    Returns:
        False for the three encoder-style types, True otherwise.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types