first commit
This commit is contained in:
17
vllm_br/v1/attention/backends/__init__.py
Normal file
17
vllm_br/v1/attention/backends/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import mla # noqa: F401
|
||||
from .utils import *
|
||||
Binary file not shown.
Binary file not shown.
BIN
vllm_br/v1/attention/backends/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/v1/attention/backends/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
657
vllm_br/v1/attention/backends/attention_v1.py
Normal file
657
vllm_br/v1/attention/backends/attention_v1.py
Normal file
@@ -0,0 +1,657 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
class SUPAFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the SUPA flash-attention implementation.

    Ties together the implementation, metadata and metadata-builder classes
    defined in this module and exposes the KV-cache shape helpers.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes accepted by this backend."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes accepted by :meth:`validate_head_size`."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not a supported head size."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical 5-D KV-cache shape.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the 4-D "usharp" KV-cache shape.

        Blocks are grouped ``th_gran`` at a time (see
        :meth:`get_kv_cache_usharp_alignment`); the block count is rounded up
        to a whole number of groups and the head dimensions are fused.
        """
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division, with at least one group even for num_blocks == 0.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        rows_per_group = th_gran * block_size
        fused_hidden = num_kv_heads * head_size
        # Lazy %-formatting: the message is only rendered when DEBUG is on.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s], "
            "for SUPA speed up, use [2, %s, %s, %s]",
            num_blocks, block_size, num_kv_heads, head_size,
            n_block, rows_per_group, fused_hidden)
        return (2, n_block, rows_per_group, fused_hidden)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit into 2048 rows.

        NOTE(review): the result is 0 when ``block_size`` exceeds 2048, which
        makes :meth:`get_kv_cache_usharp_shape` divide by zero -- confirm that
        such block sizes are rejected before this point.
        """
        max_h_limit = 2048
        return max_h_limit // block_size

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical shape to the memory layout.

        Raises:
            ValueError: if the configured cache layout is neither "NHD" nor
                "HND".
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 KV-cache dtype string to the matching torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not "fp8" or "fp8_e4m3".
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
|
||||
@dataclass
class SUPAFlashAttentionMetadata:
    # Per-step attention metadata produced by the metadata builder and
    # consumed by the SUPA attention implementation.
    #
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Token count with padding excluded.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # Parameters specific to the BIREN attention kernels.
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    # Set to False when attention split is used.
    do_cache: bool
    num_actual_reqs: torch.Tensor

    # Graph mode the current step runs under.
    supagraph_runtime_mode: SUPAGraphMode

    # Prefill / decode split of the batch.
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional ahead-of-time scheduling data.
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    causal: bool = True

    # Local attention is currently disabled; the sketch below is kept for
    # reference only.
    # @dataclass
    # class LocalAttentionMetadata:
    #     local_query_start_loc: torch.Tensor
    #     local_seqused_k: torch.Tensor
    #     local_block_table: torch.Tensor
    #     local_max_query_len: int
    #     local_max_seq_len: int
    #     local_scheduler_metadata: Optional[torch.Tensor]

    # local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
class SUPAFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[SUPAFlashAttentionMetadata]):
    """Builds :class:`SUPAFlashAttentionMetadata` for each scheduling step."""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.ALWAYS

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the configuration values that :meth:`build` relies on."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is switched off for SUPA attention, so no
        # scheduler-metadata buffer is pre-allocated here.
        self.aot_schedule = False

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SUPAFlashAttentionMetadata:
        """Assemble the attention metadata for the current batch.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value greater than 0 enables the cascade fields.
            common_attn_metadata: backend-agnostic metadata of the batch. The
                SUPA-specific attributes ``num_actual_reqs``,
                ``seq_start_loc``, ``context_lens``, ``max_decode_seq_len``
                and ``supagraph_runtime_mode`` are read from it as well.
            fast_build: disables AOT scheduling, used when there will be few
                iterations i.e. spec-decode.

        Returns:
            The populated :class:`SUPAFlashAttentionMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal
        num_actual_reqs = common_attn_metadata.num_actual_reqs
        seq_start_loc = common_attn_metadata.seq_start_loc
        context_lens = common_attn_metadata.context_lens

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. It has to be populated on the first
            # build() call, once the layers are constructed (it cannot be
            # populated in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
                num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # AOT scheduling is not available on SUPA; the arguments are kept
            # only so the call sites mirror the upstream builder.
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        # TODO: local (chunked) attention metadata is not built here yet.

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Use the builder's own device: `self.runner` is never assigned
            # on this class, so the former `self.runner.device` would raise.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # Fall back to the longest sequence in the batch when the caller did
        # not provide a decode-specific maximum.
        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        attn_metadata = SUPAFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs,
            supagraph_runtime_mode=common_attn_metadata.supagraph_runtime_mode)

        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is never selected by this builder."""
        return False
|
||||
|
||||
|
||||
class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to the ``torch_br`` kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Validate and store the per-layer attention configuration.

        Raises:
            ValueError: if ``head_size`` is unsupported, or ``sinks`` has the
                wrong number of heads or is not float32.
            NotImplementedError: if ``attn_type`` is neither DECODER nor
                ENCODER_ONLY, or an fp8 KV cache is requested but not
                supported.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="cpu")
        self.alibi_slopes = alibi_slopes
        # A falsy sliding window (None or 0) is normalised to None.
        self.sliding_window = sliding_window or None
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        SUPAFlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type

        if attn_type not in (AttentionType.DECODER,
                             AttentionType.ENCODER_ONLY):
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks: Optional[torch.Tensor] = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}.")
            if sinks.dtype != torch.float32:
                raise ValueError("Sinks must be of type float32, but got "
                                 f"{sinks.dtype}.")
            self.sinks = sinks

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with the SUPA attention kernels.

        Shapes below are carried over from the original docstring; TODO
        confirm them against the SUPA kernels.

        Args:
            layer: the attention layer; not used by this implementation.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` marks a profiling
                run, in which case ``query`` is returned untouched.
            output: must be ``None``; output buffers are not supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is None, "Output tensor should not provided."
        if attn_metadata is None:
            # FIXME: this may lead to wrong block estimation
            # Profiling run.
            return query

        is_encoder = self.attn_type in (AttentionType.ENCODER_ONLY,
                                        AttentionType.ENCODER)
        # NOTE: supa attn use [batch_size, num_tokens, num_heads * head_size] as shape
        if kv_cache is not None and attn_metadata.do_cache and not is_encoder:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        if self.sinks is not None:
            return self.forward_sw_sinks(query, kv_cache, attn_metadata)

        if is_encoder:
            # Encoder attention works on key/value directly, not the cache.
            assert len(query.shape) == 3
            return torch_br.supa_flash_attention_infer(  # type: ignore
                query,
                key,
                value,
                attn_metadata.query_start_loc,
                self.head_size,
                len(attn_metadata.query_start_loc),  # type: ignore
                self.alibi_slopes,
                softmax_scale=self.scale,
                is_causal=_get_causal_option(self.attn_type))

        has_prefill = attn_metadata.num_prefill_tokens > 0
        graph_mode = attn_metadata.supagraph_runtime_mode
        decode_kernel_allowed = graph_mode is None or graph_mode in (
            SUPAGraphMode.NONE, SUPAGraphMode.FULL_DECODE_ONLY)

        # The dedicated decode kernel is only used for decode-only batches
        # (non-mtp) in the NONE / FULL_DECODE_ONLY modes; every other case
        # goes through the prefill kernel.
        if decode_kernel_allowed and not has_prefill:
            return torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                self.alibi_slopes,
                softmax_scale=self.scale)
        return self._prefill_with_kvcache(query, kv_cache, attn_metadata)

    def _prefill_with_kvcache(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run the KV-cache flash-attention kernel on the whole batch."""
        return torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            self.head_size,
            alibi_slopes=self.alibi_slopes,
            softmax_scale=self.scale,
            num_reqs=attn_metadata.num_actual_reqs)

    # sliding window with sinks impl
    def forward_sw_sinks(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run the sliding-window kernel using the configured sinks."""
        # prefix-enabled attention
        output = torch_br.supa_flash_attn_cache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            window_size=self.sliding_window,
            sinks=self.sinks)

        return output
|
||||
|
||||
|
||||
def _get_causal_option(attn_type: str) -> bool:
    """
    Tell whether causal masking applies to the given attention type.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention, `True` for every other attention type.
    """
    non_causal_kinds = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return not any(attn_type == kind for kind in non_causal_kinds)
|
||||
19
vllm_br/v1/attention/backends/mla/__init__.py
Normal file
19
vllm_br/v1/attention/backends/mla/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import flashmla # noqa: F401
|
||||
from . import flashmla_sparse # noqa: F401
|
||||
from . import indexer # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
657
vllm_br/v1/attention/backends/mla/flashmla.py
Normal file
657
vllm_br/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,657 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonImpl,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.attention.backends.mla.flashmla import (FlashMLABackend,
|
||||
FlashMLAMetadata)
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm_br import envs
|
||||
from vllm_br.model_executor.layers.br_utils import _convert_to_numa_tensor
|
||||
from vllm_br.utils import get_grandparent_pid
|
||||
from vllm_br.v1.attention.backends.utils import SUPACommonAttentionMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupaFlashMLABackend(FlashMLABackend):
    """Backend descriptor for the SUPA FlashMLA implementation."""

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["SupaFlashMLAMetadata"]:
        return SupaFlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLAMetadataBuilder"]:
        return SupaFlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLAImpl"]:
        return SupaFlashMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical 5-D KV-cache shape.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the 4-D "usharp" KV-cache shape with a leading dim of 1.

        Blocks are grouped ``th_gran`` at a time (see
        :meth:`get_kv_cache_usharp_alignment`); the block count is rounded up
        to a whole number of groups and the head dimensions are fused.
        """
        th_gran = SupaFlashMLABackend.get_kv_cache_usharp_alignment(block_size)
        # Ceiling division, with at least one group even for num_blocks == 0.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        rows_per_group = th_gran * block_size
        fused_hidden = num_kv_heads * head_size
        # Lazy %-formatting: the message is only rendered when DEBUG is on.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s], "
            "for SUPA speed up, use [1, %s, %s, %s]",
            num_blocks, block_size, num_kv_heads, head_size,
            n_block, rows_per_group, fused_hidden)
        # TODO, shared kv only used in deepseek
        return (1, n_block, rows_per_group, fused_hidden)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit into 2048 rows.

        NOTE(review): the result is 0 when ``block_size`` exceeds 2048, which
        makes :meth:`get_kv_cache_usharp_shape` divide by zero -- confirm that
        such block sizes are rejected before this point.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
|
||||
|
||||
|
||||
@dataclass
class SupaFlashMLAMetadata:
    """Per-batch attention metadata consumed by the SUPA FlashMLA backend.

    Fields without defaults must be supplied by the metadata builder; the
    two scheduler-metadata fields at the end default to ``None``.
    """
    # class SupaFlashMLAMetadata(FlashMLAMetadata):
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Presumably cumulative query offsets per request (length num_reqs + 1)
    # -- TODO confirm against the builder's input.
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # BIREN Attention Params
    # Cumulative sequence-length offsets, beginning with 0.
    seq_start_loc: torch.Tensor
    # Per-request seq_len minus query_len (see the diagram above).
    context_lens: torch.Tensor
    max_decode_seq_len: int
    do_cache: bool  # when use attentionsplit, do cache = False

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    # NOTE(review): annotated as a tensor although the name suggests a
    # count -- confirm what the producer actually stores here.
    num_actual_reqs: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class SupaFlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds :class:`SupaFlashMLAMetadata` for each scheduled batch.

    AOT scheduling is always disabled for this backend, so the scheduler
    metadata fields of the produced object are ``None``.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the builder and its optional full-cudagraph buffers."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        self.vllm_config = vllm_config
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        device_properties = torch.cuda.get_device_properties(self.device)
        num_sms = device_properties.multi_processor_count

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

        self.aot_schedule = False
        logger.warning(
            "AOT Schedule is disabled when using SUPAFlashAttention.")

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: SUPACommonAttentionMetadata,
              fast_build: bool = False):
        """Assemble the attention metadata for one batch.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value greater than 0 enables the cascade fields.
            common_attn_metadata: backend-independent batch description.
            fast_build: when True, AOT scheduling is skipped for this call.

        Returns:
            A populated :class:`SupaFlashMLAMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        max_seq_len = int(seq_lens_cpu[:num_reqs].max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        num_actual_reqs = common_attn_metadata.num_actual_reqs

        aot_schedule = self.aot_schedule and not fast_build

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # AOT scheduling is unsupported here; the only valid result is
            # "no scheduler metadata".
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # The builder is constructed without a model runner, so the
            # device and host sequence lengths are taken from the builder
            # and from common_attn_metadata respectively.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] -
                              common_prefix_len).to(self.device)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=True)

        if common_attn_metadata.seq_start_loc is None:
            # Exclusive prefix sum of seq_lens with a leading 0, computed on
            # one host copy instead of element-by-element accumulation.
            host_seq_lens = seq_lens.cpu()
            seq_start_loc = torch.zeros(host_seq_lens.numel() + 1,
                                        dtype=torch.int32)
            seq_start_loc[1:] = torch.cumsum(host_seq_lens,
                                             dim=0,
                                             dtype=torch.int32)
            seq_start_loc = seq_start_loc.to(query_start_loc.device)
        else:
            seq_start_loc = common_attn_metadata.seq_start_loc

        if common_attn_metadata.context_lens is None:
            # context_len = seq_len - query_len for every request.
            context_lens = seq_lens - (query_start_loc[1:] -
                                       query_start_loc[:-1])
        else:
            context_lens = common_attn_metadata.context_lens

        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        attn_metadata = SupaFlashMLAMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs)

        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: SUPACommonAttentionMetadata) -> bool:
        """Report whether this batch may run inside a full CUDA graph."""
        # Always False: this builder never opts a batch into full graphs.
        return False

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is never requested by this backend."""
        return False
|
||||
|
||||
|
||||
# class SupaFlashMLAImpl(FlashMLAImpl):
|
||||
class SupaFlashMLAImpl(MLACommonImpl[SupaFlashMLAMetadata]):
    """MLA attention implementation that dispatches to ``torch_br`` kernels.

    ``process_weights_after_loading`` fuses the projection weights into
    ``w_qkv_a``, ``w_qk_b`` and ``w_vo``; ``forward`` then runs projection,
    attention and output projection using only those fused weights.
    Only the ``q_lora_rank is not None`` configuration is supported.
    """
    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            q_lora_rank: Optional[int],
            kv_lora_rank: int,
            qk_nope_head_dim: int,
            qk_rope_head_dim: int,
            qk_head_dim: int,
            v_head_dim: int,
            kv_b_proj: ColumnParallelLinear,
            rotary_emb: RotaryEmbedding,
            # # q_proj should be q_b_proj if q_lora_rank is not None, but from an
            # # attention backend perspective we rely on the layer to pass in the
            # # correct matrix
            q_proj: ColumnParallelLinear,  # q_b_proj
            # kv_b_proj: ColumnParallelLinear,
            o_proj: RowParallelLinear,
            kv_a_proj_with_mqa: ReplicatedLinear,
            kv_a_layernorm: Any,
            q_a_proj: ReplicatedLinear,
            q_a_layernorm: Any,

            # MLA Specific Arguments
            **mla_args) -> None:
        """Store the projection layers and validate the configuration.

        Raises:
            NotImplementedError: if ``alibi_slopes``, ``sliding_window`` or
                ``logits_soft_cap`` is truthy, if ``attn_type`` is not
                ``AttentionType.DECODER``, or if the KV cache dtype is
                quantized.
        """

        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, q_lora_rank,
                         kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim,
                         qk_head_dim, v_head_dim, kv_b_proj, **mla_args)

        self.rotary_emb = rotary_emb

        # Keep references to every projection layer; their weights are fused
        # later in process_weights_after_loading.
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.kv_a_proj_with_mqa = kv_a_proj_with_mqa
        self.kv_a_layernorm = kv_a_layernorm
        self.q_a_layernorm = q_a_layernorm
        self.q_a_proj = q_a_proj

        self.tp_size = get_tensor_model_parallel_world_size()
        cur_device = torch.supa.current_device()
        self.spc_num = torch_br.supa.get_device_properties(
            cur_device).max_compute_units

        if envs.VLLM_BR_USE_FUSED_ALLREDUCE and self.tp_size == 8 and self.spc_num == 16:
            # Initialize the p2p info
            torch.supa.init_p2p_remote_id(cur_device)

        # NOTE(review): an assert is stripped under ``python -O``;
        # process_weights_after_loading raises NotImplementedError for the
        # same condition.
        assert self.q_lora_rank is not None
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "SUPAFlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "SUPAFlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "SUPAFlashMLA V1 with FP8 KV cache not yet supported")

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Fuse the MLA projection weights into the tensors used by forward.

        Produces ``self.w_qkv_a`` (q_a and kv_a concatenated),
        ``self.w_vo`` (V up-projection multiplied into the output
        projection) and ``self.w_qk_b`` (K up-projection multiplied into the
        no-rope part of q_b, concatenated with the rope part). The source
        layers' ``weight`` attributes are then set to ``None`` and the
        layernorm weights are converted to float32.

        Args:
            act_dtype: dtype of the identity matrix used to dequantize
                quantized layers.

        Raises:
            NotImplementedError: if ``self.q_lora_rank`` is ``None``.
        """

        def get_layer_weight(layer):
            # Return the first weight-like attribute present on the layer.
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            # Quantized layers are dequantized by pushing an identity matrix
            # through the layer's quant method; unquantized layers return
            # their weight directly.
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        if self.q_lora_rank is not None:
            # handle deepseek_v3 weight
            w_q_a = get_and_maybe_dequant_weights(self.q_a_proj).T
            w_kv_a = get_and_maybe_dequant_weights(self.kv_a_proj_with_mqa).T
            w_qkv_a = torch.cat([w_q_a, w_kv_a], dim=-1)
            # w_qkv_a must make two copies in br166
            align_size = 32
            die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
            if die_spc_num > 16:
                w_qkv_a = torch.cat([w_qkv_a, w_qkv_a], dim=-1)
            self.w_qkv_a = _convert_to_numa_tensor(w_qkv_a, align_size,
                                                   "colmajor", w_qkv_a.dtype)

            # Split kv_b into its K (no-rope) and V halves per head.
            w_kv_b = get_and_maybe_dequant_weights(self.kv_b_proj).T
            w_k_b, w_v_b = w_kv_b.reshape(
                self.kv_lora_rank, -1,
                self.qk_nope_head_dim + self.v_head_dim).split(
                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            w_k_b = w_k_b.permute(1, 2, 0).contiguous()
            w_v_b = w_v_b.permute(1, 0, 2).contiguous()

            # Fold the V up-projection into the output projection.
            w_o = get_and_maybe_dequant_weights(self.o_proj.to(w_v_b.device)).T
            hidden_dim = w_o.shape[-1]
            w_o = w_o.reshape(-1, self.v_head_dim, hidden_dim)
            w_vo = torch.bmm(w_v_b, w_o).reshape(-1, hidden_dim)
            self.w_vo = _convert_to_numa_tensor(w_vo,
                                                align_size,
                                                "colmajor",
                                                w_qkv_a.dtype,
                                                parallel_type="row_parallel")

            # replace q_b_proj as q_proj
            w_q_b = get_and_maybe_dequant_weights(self.q_proj).T
            w_q_b_nope, w_q_b_rope = w_q_b.reshape(
                self.q_lora_rank, -1, self.qk_head_dim).split(
                    [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            w_q_b_nope = w_q_b_nope.permute(1, 0, 2).contiguous()
            w_q_b_rope = w_q_b_rope.reshape(self.q_lora_rank, -1)

            # Fold the K up-projection into the no-rope part of q_b.
            w_qk_b_nope = torch.bmm(w_q_b_nope, w_k_b).permute(
                1, 0, 2).contiguous().reshape(self.q_lora_rank, -1)
            # w_qk_b_nope w_q_b_rope is independent head, separate like QKVParallelLinear
            if die_spc_num > 16:
                # Interleave the two halves of nope/rope rather than
                # concatenating nope then rope.
                qk_b_nope0, qk_b_nope1 = torch.chunk(w_qk_b_nope, 2, dim=-1)
                qk_b_rope0, qk_b_rope1 = torch.chunk(w_q_b_rope, 2, dim=-1)
                w_qk_b = torch.cat(
                    [qk_b_nope0, qk_b_rope0, qk_b_nope1, qk_b_rope1], dim=-1)
            else:
                w_qk_b = torch.cat([w_qk_b_nope, w_q_b_rope], dim=-1)
            self.w_qk_b = _convert_to_numa_tensor(w_qk_b, align_size,
                                                  "colmajor", w_qkv_a.dtype)

            # The fused tensors replace the per-layer weights, which are
            # dropped here.
            self.q_a_proj.weight = None
            self.kv_a_proj_with_mqa.weight = None
            self.q_proj.weight = None
            self.kv_b_proj.weight = None
            self.o_proj.weight = None

            if self.kv_a_layernorm.weight.dtype != torch.float32:
                self.kv_a_layernorm.weight.data = self.kv_a_layernorm.weight.to(
                    torch.float32)
            if self.q_a_layernorm.weight.dtype != torch.float32:
                self.q_a_layernorm.weight.data = self.q_a_layernorm.weight.to(
                    torch.float32)
        else:
            raise NotImplementedError
        torch.supa.empty_cache()

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Not implemented; ``forward`` handles decode itself."""
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states: torch.Tensor,  # query in unified attn
        positions: torch.Tensor,  # reuse k_c_normed as position
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLAMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run fused projection, attention and output projection.

        Args:
            layer: the attention layer; not read in this method.
            hidden_states: input activations, passed in the slot that
                unified attention calls "query".
            positions: token positions, passed in the slot that unified
                attention calls "k_c_normed".
            k_pe: not read in this method.
            kv_cache: paged KV cache that the prefix kernel is given along
                with ``attn_metadata.slot_mapping``.
            attn_metadata: metadata for this batch; ``None`` during
                profiling/warm-up.
            output: must be ``None``; output buffers are not accepted.

        Returns:
            ``hidden_states`` unchanged when ``attn_metadata`` is ``None``;
            otherwise the output-projection result, all-reduced across
            tensor-parallel ranks when more than one rank is in use.
        """
        assert output is None, "Output tensor should not provided."
        # Cache the grandparent pid once per instance when CPU all-reduce
        # is enabled.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()

        # profile and warm up mla attention kernel
        if attn_metadata is None:
            return hidden_states

        # handle deepseek_v3 mla
        # Both kernels take identical arguments; the choice depends only on
        # hidden_states.shape[1].
        if hidden_states.shape[1] <= 512:
            query, key = torch_br.supa_mla_prefix_infer_v2(
                hidden_states, self.w_qkv_a, self.w_qk_b,
                self.q_a_layernorm.weight, self.kv_a_layernorm.weight,
                self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
                positions, kv_cache, attn_metadata.slot_mapping,
                self.num_heads, self.qk_head_dim, self.qk_nope_head_dim,
                self.qk_rope_head_dim, self.kv_lora_rank, self.v_head_dim,
                self.q_lora_rank, self.kv_a_layernorm.variance_epsilon)
        else:
            query, key = torch_br.supa_mla_prefix_infer_v3(
                hidden_states, self.w_qkv_a, self.w_qk_b,
                self.q_a_layernorm.weight, self.kv_a_layernorm.weight,
                self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
                positions, kv_cache, attn_metadata.slot_mapping,
                self.num_heads, self.qk_head_dim, self.qk_nope_head_dim,
                self.qk_rope_head_dim, self.kv_lora_rank, self.v_head_dim,
                self.q_lora_rank, self.kv_a_layernorm.variance_epsilon)

        # NOTE(review): ``output`` is reassigned by both attention branches
        # below, so this allocation looks unused -- confirm whether it can
        # be removed.
        if query.shape[0] == 1:
            output = torch.empty_like(query)
        else:
            output = torch_br._empty_ut_only(
                [1, query.shape[1], query.shape[0] * self.kv_lora_rank],
                device=query.device,
                dtype=query.dtype,
                tensor_type="colmajor",
                axis=2,
                sbp="SB" if envs.VLLM_BR_DEVICE_SPC_NUM > 16 else None)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        #decoder_qloc = attn_metadata.query_start_loc[:attn_metadata.num_decodes + 1].cpu()
        #if decoder_qloc.shape[0] > 1:
        #    assert torch.all(torch.diff(decoder_qloc) == 1), f"Must ensure that it is an increasing queue with a step of 1 !\nq_loc:{attn_metadata.query_start_loc}"
        #print("num_prefill_tokens:", num_prefill_tokens)
        # Any prefill token sends the whole batch through the flash-attn
        # kernel; a batch with no prefill tokens uses the decoder kernel.
        if num_prefill_tokens > 0:
            assert len(query.shape) == 3
            output = torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
                query,
                kv_cache,
                attn_metadata.query_start_loc,
                attn_metadata.seq_start_loc,
                attn_metadata.block_table,
                self.head_size,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
                num_reqs=attn_metadata.num_actual_reqs,
            )
        else:
            assert len(query.shape) == 3 and attn_metadata.num_prefills == 0
            output = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
            )

        # now linear+allreduce only support M <= 512 and tp_size == 4 | 8 and spc_num == 16
        # NOTE(review): support_types below also lists spc_num == 32 with
        # tp_size 2 or 4, and the length bound comes from an env setting --
        # the comment above may be out of date.
        seq_len = hidden_states.shape[-2]
        # NOTE(review): re-queried here although self.tp_size is used in
        # the non-fused branch; presumably the same value.
        tp_size = get_tensor_model_parallel_world_size()
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        fused_comm = (envs.VLLM_BR_USE_FUSED_ALLREDUCE
                      and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                      and
                      (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)

        if fused_comm:
            tp_rank = get_tp_group().rank_in_group
            global_rank = get_tp_group().rank
            rank_i = global_rank % tp_size
            assert rank_i == tp_rank
            o_proj_out = torch_br.supa_fused_linear_allreduce_opt(
                output, self.w_vo, hidden_states.shape[-1], tp_rank, tp_size,
                global_rank, 0)
        else:
            # do o_proj
            output_parallel = torch_br.br_fused_mlp_infer(
                output, [self.w_vo], output_w=hidden_states.shape[-1])
            if self.tp_size > 1:
                o_proj_out = tensor_model_parallel_all_reduce(output_parallel)
            else:
                o_proj_out = output_parallel

        return o_proj_out
|
||||
450
vllm_br/v1/attention/backends/mla/flashmla_sparse.py
Normal file
450
vllm_br/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,450 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata
|
||||
from vllm.attention.ops.flashmla import get_mla_metadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend, FlashMLASparseImpl, FlashMLASparseMetadata,
|
||||
FlashMLASparseMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_NO_DEFAULT = object()
|
||||
|
||||
|
||||
@dataclass
class SupaFlashMLASparseMetadata(FlashMLASparseMetadata):
    """Sparse FlashMLA metadata extended with the SUPA-specific fields.

    The extra fields are required, but a dataclass subclass cannot add
    non-default fields after inherited defaults. They therefore carry
    sentinel defaults (``_NO_DEFAULT`` or ``-1``) that ``__post_init__``
    rejects.
    """
    # BIREN Attention Params
    seq_start_loc: torch.Tensor = _NO_DEFAULT
    context_lens: torch.Tensor = _NO_DEFAULT
    max_decode_seq_len: int = -1
    num_prefills: int = -1
    num_decodes: int = -1
    num_prefill_tokens: int = -1
    num_decode_tokens: int = -1

    def __post_init__(self):
        """Verify that every SUPA-specific field was supplied.

        Raises:
            TypeError: naming each field still at its sentinel default.
        """
        missing = [
            name for name in ("seq_start_loc", "context_lens")
            if getattr(self, name) is _NO_DEFAULT
        ]
        missing.extend(
            name for name in ("max_decode_seq_len", "num_prefills",
                              "num_decodes", "num_prefill_tokens",
                              "num_decode_tokens")
            if getattr(self, name) == -1)
        if missing:
            raise TypeError("__init__ missing required argument(s): " +
                            ", ".join(missing))
|
||||
|
||||
|
||||
class SupaFlashMLASparseMetadataBuilder(FlashMLASparseMetadataBuilder):
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    """Initialise the parent builder and derive the reorder threshold.

    The decode threshold grows by one when speculative decoding is
    configured with at least one speculative token.
    """
    super().__init__(
        kv_cache_spec=kv_cache_spec,
        layer_names=layer_names,
        vllm_config=vllm_config,
        device=device,
    )
    self.vllm_config = vllm_config
    if self.vllm_config.speculative_config:
        spec_tokens = \
            self.vllm_config.speculative_config.num_speculative_tokens
    else:
        spec_tokens = 0
    self.num_speculative_tokens = spec_tokens
    # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
    extra_decode_tokens = min(self.num_speculative_tokens, 1)
    self.reorder_batch_threshold += extra_decode_tokens
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so prefills sit at the front, decodes at the back.

    A request counts as a decode when its scheduled token count minus its
    speculative token count equals 1; everything else is a prefill. Prefills
    found past the prefill region are swapped with the earliest decodes.
    The per-kind request and token counts are stored on ``self`` for the
    next ``build`` call.

    Returns:
        True if at least one pair of requests was swapped.
    """
    decode_slots: list[int] = []
    prefill_slots: list[int] = []
    decode_token_total = 0
    prefill_token_total = 0

    for slot, req_id in enumerate(input_batch.req_ids):
        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        speculative = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        # Exactly one non-speculative token marks a decode request.
        if scheduled - speculative != 1:
            prefill_slots.append(slot)
            prefill_token_total += scheduled
            continue
        decode_slots.append(slot)
        decode_token_total += scheduled

    decode_count = len(decode_slots)
    prefill_count = len(prefill_slots)
    swapped = False

    # Both lists are in ascending slot order. Walk prefills from the back
    # and decodes from the front; zip stops after min(decodes, prefills)
    # pairs. Once a prefill already lies inside the first `prefill_count`
    # slots, all remaining prefills do too.
    for early_decode, late_prefill in zip(decode_slots,
                                          reversed(prefill_slots)):
        if late_prefill < prefill_count:
            break
        input_batch.swap_states(early_decode, late_prefill)
        swapped = True

    # Save for next `build` call
    # TODO(lucas): this is a bit of a hack, we should probably have a
    # better way of doing this
    self._num_decodes = decode_count
    self._num_prefills = prefill_count
    self._num_decode_tokens = decode_token_total
    self._num_prefill_tokens = prefill_token_total

    return swapped
|
||||
|
||||
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> SupaFlashMLASparseMetadata:
    """Assemble the per-step sparse MLA attention metadata for SUPA.

    Fills the persistent ``req_id_per_token`` buffer, optionally prepares
    the FP8 kernel scheduling metadata, and derives the extra fields
    (``seq_start_loc``, ``context_lens``, ``max_decode_seq_len`` and the
    decode/prefill split) that are stored on the returned metadata.

    Args:
        common_prefix_len: Not used by this builder.
        common_attn_metadata: Batch-level metadata. Its ``seq_start_loc``,
            ``context_lens`` and ``max_decode_seq_len`` attributes are used
            as-is when set and derived from ``seq_lens`` /
            ``query_start_loc`` when they are ``None``.
        fast_build: Not used by this builder.

    Returns:
        The populated ``SupaFlashMLASparseMetadata``.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    # Map every scheduled token to the index of the request it belongs to.
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                        dtype=np.int32)
    seg_lengths = np.diff(starts)
    req_id_per_token = np.repeat(
        np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
    # Zero-fill for cudagraphs
    self.req_id_per_token_buffer.fill_(0)
    self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
        .copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
    req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

    fp8_extra_metadata = None
    if self.use_fp8_kv_cache:
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor,
            num_q_tokens_per_head_k=num_tokens * self.num_heads,
            topk=self.topk_tokens,
            num_heads_q=self.num_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        num_sm_parts = tile_scheduler_metadata.size(0)
        # Copy to persistent buffer for full-CG support
        tile_scheduler_metadata_buffer = \
            self.tile_scheduler_metadata_buffer[:num_sm_parts]
        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
        self.num_splits_buffer.copy_(num_splits)

        fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=tile_scheduler_metadata_buffer,
            num_splits=self.num_splits_buffer,
            # cache_lens and block_table are basically unused in sparse case
            # but the decode kernel will treat -1 and indices >= cache_lens
            # as invalid so we make sure cache_lens is large enough to not
            # accidentally mark indices invalid, we will use -1 exclusively
            # to mark invalid indices
            cache_lens=self.max_model_len_tensor,
            dummy_block_table=self.dummy_block_table)

    # Add biren attention params
    seq_start_loc = common_attn_metadata.seq_start_loc
    if seq_start_loc is None:
        # A single tolist() moves the lengths to the host once; the prefix
        # sum then runs on plain Python ints instead of accumulating one
        # 0-d tensor per request.
        seq_start_loc = torch.tensor(
            [0, *itertools.accumulate(seq_lens.tolist())],
            device=query_start_loc.device,
            dtype=torch.int32)

    context_lens = common_attn_metadata.context_lens
    if context_lens is None:
        # Tokens already in the cache = total length - tokens scheduled now.
        context_lens = seq_lens - (query_start_loc[1:] -
                                   query_start_loc[:-1])

    max_decode_seq_len = common_attn_metadata.max_decode_seq_len
    if max_decode_seq_len is None:
        max_decode_seq_len = int(seq_lens.max().item())

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold)
    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_tokens

    return SupaFlashMLASparseMetadata(
        num_reqs=num_reqs,
        max_query_len=common_attn_metadata.max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        num_actual_tokens=num_tokens,
        query_start_loc=query_start_loc,
        slot_mapping=common_attn_metadata.slot_mapping,
        block_table=common_attn_metadata.block_table_tensor,
        req_id_per_token=req_id_per_token,
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
        fp8_extra_metadata=fp8_extra_metadata,
        seq_start_loc=seq_start_loc,
        context_lens=context_lens,
        max_decode_seq_len=max_decode_seq_len,
        num_prefills=num_prefills,
        num_decodes=num_decodes,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
    )
|
||||
|
||||
|
||||
class SupaFlashMLASparseBackend(FlashMLASparseBackend):
    """Backend descriptor wiring the SUPA sparse-MLA metadata, builder and
    implementation classes together, plus the SUPA KV-cache shape helpers."""

    @staticmethod
    def get_name() -> str:
        """Return the backend's registered name."""
        return "SUPA_FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the metadata class produced by the builder."""
        return SupaFlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class."""
        return SupaFlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLASparseImpl"]:
        """Return the attention implementation class."""
        return SupaFlashMLASparseImpl

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many ``block_size`` blocks fit into the 2048-row limit
        (integer division)."""
        return 2048 // block_size

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the regrouped KV-cache shape used on SUPA.

        ``th_gran`` consecutive blocks are merged into one larger block and
        the head dimensions are flattened, giving
        ``(2, n_block, th_gran * block_size, num_kv_heads * head_size)``.
        The merged block count is rounded up and is never below 1.
        """
        th_gran = SupaFlashMLASparseBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division, clamped so at least one merged block exists.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            f'Origin kv cache shape is [2, {num_blocks}, {block_size}, {num_kv_heads}, {head_size}, For SUPA Speed up, use [2, {n_block}, {th_gran * block_size}, {num_kv_heads * head_size}]'  # noqa: G004
        )
        merged_block_rows = th_gran * block_size
        flat_head_width = num_kv_heads * head_size
        return (2, n_block, merged_block_rows, flat_head_width)
|
||||
|
||||
|
||||
class SupaFlashMLASparseImpl(FlashMLASparseImpl):
    """Sparse FlashMLA attention implementation that runs the attention and
    KV-cache store through ``torch_br`` SUPA kernels."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        """Forward every argument unchanged to ``FlashMLASparseImpl``."""
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, topk_indice_buffer,
                         indexer, **mla_args)

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: SupaFlashMLASparseMetadata) -> torch.Tensor:
        """Run sparse attention over the non-FP8 KV cache.

        Args:
            q: Query tensor; unpacked as ``(seq_len_q, num_heads, dim)``.
            kv_c_and_k_pe_cache: KV cache; only its first leading slice
                (``[:1]``) is handed to the kernel.
            topk_indices: Per-query-row selected key indices, indexed as
                ``[row, k]``. Entries outside ``[0, seq_len_q)`` are
                ignored.
            attn_metadata: Metadata supplying the kernel's location, block
                table and length arguments.

        Returns:
            Attention output reshaped to
            ``(seq_len_q, num_heads, self.kv_lora_rank)``.
        """
        seq_len_q, num_heads, _ = q.shape

        # Build the [1, seq_len_q, seq_len_q] int32 mask: 1 everywhere and
        # 0 at (row, topk_indices[row, k]) for every in-range index.
        # This replaces a per-element Python loop with one scatter.
        # Out-of-range indices are redirected to a spare sink column
        # (index seq_len_q) that is sliced away afterwards, so they never
        # clobber a real entry.
        row_topk = topk_indices[:seq_len_q]
        in_range = (row_topk >= 0) & (row_topk < seq_len_q)
        scatter_cols = row_topk.masked_fill(~in_range, seq_len_q).to(
            device=q.device, dtype=torch.int64)
        padded_mask = torch.ones((seq_len_q, seq_len_q + 1),
                                 dtype=torch.int32,
                                 device=q.device)
        padded_mask.scatter_(1, scatter_cols, 0)
        index_mask = padded_mask[:, :seq_len_q].unsqueeze(0).contiguous()

        query = q.transpose(0,
                            1).contiguous()  # [num_heads, seq_len, head_dim]
        output = torch_br.supa_flash_attn_cache_infer(
            query,
            kv_c_and_k_pe_cache[:
                                1],  # [1, num_blocks, block_size, self.head_size]
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            softmax_scale=self.softmax_scale,
            v_head_size=self.kv_lora_rank,
            mask=index_mask)

        output = output.reshape(seq_len_q, num_heads,
                                self.kv_lora_rank).contiguous()
        return output

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Store the new latent/rope keys and compute sparse MLA attention.

        Args:
            layer: Not used by this implementation.
            q: Query tensor, split on the last dim into
                ``qk_nope_head_dim`` and ``qk_rope_head_dim`` parts.
            k_c_normed: Latent key tensor written to the cache.
            k_pe: Rope key tensor written to the cache.
            kv_cache: Paged cache; skipped for the store when empty.
            attn_metadata: Step metadata; ``None`` short-circuits to a
                zero-filled ``output``.
            output: Pre-allocated output tensor (required).
            output_scale: Must be ``None``.
            output_block_scale: Must be ``None``.

        Returns:
            ``output``, with its first ``num_actual_tokens`` rows written.

        Raises:
            NotImplementedError: If an output scale is supplied.
            RuntimeError: If ``kv_cache_dtype`` is ``"fp8_ds_mla"``.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        q = q[:num_actual_toks, ...]
        # NOTE(review): the concat further down needs k_c_normed to be 3-D
        # with the same leading dim (1) as k_pe_tmp; if so this slice does
        # not trim the token axis -- confirm the intended layout.
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # TODO: handle index / kv_cache correctly
        # topk_indices_global = triton_convert_req_index_to_global_index(
        #     attn_metadata.req_id_per_token,
        #     attn_metadata.block_table,
        #     topk_indices,
        #     BLOCK_SIZE=attn_metadata.block_size,
        #     NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        # )

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            # Only the trailing dim is needed; the unpack still requires a
            # 4-D cache.
            _, _, _, head_size = kv_cache.shape
            k_pe_tmp = k_pe.squeeze(1).unsqueeze(0)
            key_supa = torch.cat([k_c_normed, k_pe_tmp], dim=2)
            torch_br.supa_kvcache_store_infer_v2(kv_cache, key_supa, key_supa,
                                                 attn_metadata.slot_mapping,
                                                 head_size)

        if self.kv_cache_dtype != "fp8_ds_mla":
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                             attn_metadata)
        else:
            raise RuntimeError("Not support fp8 on br.")

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output
|
||||
140
vllm_br/v1/attention/backends/mla/indexer.py
Normal file
140
vllm_br/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,140 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerBackend, DeepSeekV32IndexerDecodeMetadata,
|
||||
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadataBuilder,
|
||||
DeepseekV32IndexerPrefillMetadata, split_prefill_chunks)
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupaDeepseekV32IndexerBackend(DeepseekV32IndexerBackend):
    """Indexer backend variant that plugs in the SUPA metadata builder and
    the SUPA KV-cache shape helpers."""

    @staticmethod
    def get_builder_cls() -> type["SupaDeepseekV32IndexerMetadataBuilder"]:
        """Return the SUPA indexer metadata builder class."""
        return SupaDeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many ``block_size`` blocks fit into the 2048-row limit
        (integer division)."""
        return 2048 // block_size

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the regrouped indexer cache shape used on SUPA.

        ``th_gran`` consecutive blocks are merged into one larger block and
        the head dimensions are flattened, giving
        ``(1, n_block, th_gran * block_size, num_kv_heads * head_size)``.
        The merged block count is rounded up and is never below 1.
        """
        th_gran = SupaDeepseekV32IndexerBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division, clamped so at least one merged block exists.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            f'Origin kv cache shape is [1, {num_blocks}, {block_size}, {num_kv_heads}, {head_size}, For SUPA Speed up, use [1, {n_block}, {th_gran * block_size}, {num_kv_heads * head_size}]'  # noqa: G004
        )
        merged_block_rows = th_gran * block_size
        flat_head_width = num_kv_heads * head_size
        return (1, n_block, merged_block_rows, flat_head_width)
|
||||
|
||||
|
||||
class SupaDeepseekV32IndexerMetadataBuilder(DeepseekV32IndexerMetadataBuilder):
    """Indexer metadata builder for SUPA; passes no scheduler metadata to the
    decode metadata."""

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Build the indexer metadata for the current batch.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch-level attention metadata.
            fast_build: Not used by this builder.

        Returns:
            A ``DeepseekV32IndexerMetadata`` whose ``prefill`` / ``decode``
            parts are ``None`` when the batch has no such requests.
        """
        meta = common_attn_metadata
        host_query_starts = meta.query_start_loc_cpu

        (num_decodes, num_prefills, num_decode_tokens,
         num_prefill_tokens) = split_decodes_and_prefills(
             meta, decode_threshold=self.reorder_batch_threshold)

        assert num_decodes + num_prefills == meta.num_reqs
        assert num_decode_tokens + num_prefill_tokens == \
            meta.num_actual_tokens

        prefill_metadata = None
        if num_prefills > 0:
            # Chunk the prefill requests, then build one metadata entry per
            # (start, end) request range.
            request_ranges = split_prefill_chunks(
                meta.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            chunks = []
            for req_begin, req_end in request_ranges:
                chunks.append(
                    self.build_one_prefill_chunk(req_begin, req_end,
                                                 host_query_starts,
                                                 meta.seq_lens_cpu,
                                                 meta.block_table_tensor))
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks)

        decode_metadata = None
        if num_decodes > 0:
            decode_end = num_decodes + 1
            # Per-request query lengths, written into the persistent buffer.
            torch.diff(meta.query_start_loc[:decode_end],
                       out=self.decode_lens_buffer[:num_decodes])
            decode_lens = self.decode_lens_buffer[:num_decodes]

            # Use CPU to avoid GPU sync; breaking async scheduling
            host_decode_lens = torch.diff(host_query_starts[:decode_end])
            longest = host_decode_lens.max()
            shortest = host_decode_lens.min()
            requires_padding = (longest > shortest).item()

            # The get_paged_mqa_logits_metadata call that would fill this
            # buffer is disabled in this builder, so the attribute is reset
            # and None is handed to the decode metadata.
            self.scheduler_metadata_buffer = None
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=meta.block_table_tensor[:num_decodes, ...],
                seq_lens=meta.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        return DeepseekV32IndexerMetadata(
            seq_lens=meta.seq_lens,
            num_reqs=meta.num_reqs,
            max_query_len=meta.max_query_len,
            max_seq_len=meta.max_seq_len,
            num_actual_tokens=meta.num_actual_tokens,
            query_start_loc=meta.query_start_loc,
            slot_mapping=meta.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
|
||||
47
vllm_br/v1/attention/backends/utils.py
Normal file
47
vllm_br/v1/attention/backends/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
|
||||
@dataclass
class SUPACommonAttentionMetadata(CommonAttentionMetadata):
    """Common attention metadata extended with SUPA-specific optional fields.

    Holds the attributes that can be shared by layers in different KV cache
    groups (which therefore have different block tables).
    """

    # (batch_size + 1,), the start location of each request in query Tensor
    query_start_loc: torch.Tensor

    # (batch_size,), the length of each request including both computed
    # tokens and newly scheduled tokens
    seq_lens: torch.Tensor

    # (1,), number of actual requests in the batch
    num_actual_reqs: torch.Tensor | None = None

    supagraph_runtime_mode: SUPAGraphMode | None = None

    # (batch_size,), the length of each request including computed tokens
    # only
    context_lens: torch.Tensor | None = None

    # The maximum length of the decoded sequence in the batch.
    max_decode_seq_len: int | None = None

    # (batch_size + 1,), the start location of each request in sequence
    # Tensor. This is used to compute the sequence length of each request.
    # If not provided, it will be computed from seq_lens.
    seq_start_loc: torch.Tensor | None = None
|
||||
Reference in New Issue
Block a user