This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

View File

@@ -0,0 +1,503 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
logger = init_logger(__name__)
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S390X)
class CPUAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder,
encoder-only and encoder-decoder attention."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl
@staticmethod
def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
return CPUAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return 2, num_blocks, num_kv_heads, block_size, head_size
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class CPUAttentionMetadata:
isa: str
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
scheduler_metadata: torch.Tensor | None
causal: bool = True
# can be removed after deprecate sdpa
use_sdpa_prefill: bool = False
num_decode_tokens: int = 0
sdpa_attn_masks: list[torch.Tensor | None] | None = None
sdpa_start_loc: torch.Tensor | None = None
class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.use_sdpa_prefill = False
reorder_batch_threshold = None
if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
# in this case, decode seqs are reordered to the front of prefill seqs
# to split decode and prefill. Then use SDPA for prefill and
# cpu_attention_with_kv_cache for decode
reorder_batch_threshold = 1
self.use_sdpa_prefill = True
self._init_reorder_batch_threshold(reorder_batch_threshold, False)
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
parallel_config = vllm_config.parallel_config
self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
self.num_heads = vllm_config.model_config.get_num_attention_heads(
parallel_config
)
self.head_dim = kv_cache_spec.head_size
self.dtype = vllm_config.model_config.dtype
self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
if self.window_size is None:
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> CPUAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = False if self.is_cross_attention else common_attn_metadata.causal
sdpa_start_loc = query_start_loc
num_decode_tokens = 0
if self.use_sdpa_prefill and causal:
# Decoder, need reorder and truncate
assert self.reorder_batch_threshold
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
num_reqs = num_decodes
sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
seq_lens = seq_lens[:num_decodes]
query_start_loc = query_start_loc[: num_decodes + 1]
block_table_tensor = block_table_tensor[:num_decodes]
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
seq_lens=seq_lens,
dtype=self.dtype,
query_start_loc=query_start_loc,
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
)
attn_metadata = CPUAttentionMetadata(
isa=self.isa,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
scheduler_metadata=sheduler_metadata,
causal=causal,
use_sdpa_prefill=self.use_sdpa_prefill,
num_decode_tokens=num_decode_tokens,
sdpa_start_loc=sdpa_start_loc,
)
return attn_metadata
class CPUAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
if logits_soft_cap is not None and attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
logger.warning_once(
"CPU_ATTN does not support logits softcap for"
" ENCODER and ENCODER_ONLY, outputs may be slightly off"
)
if logits_soft_cap is None:
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type == AttentionType.ENCODER_ONLY:
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
self.attn_type = attn_type
self.sinks = sinks
if self.sinks is not None:
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer"
)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: CPUAttentionMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass for CPU attention backend.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, num_kv_heads, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for CPUAttentionBackendImpl"
)
# For warming-up
if attn_metadata is None:
return output
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
return self._run_sdpa_forward(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
self.attn_type,
)
# For decoder and cross-attention, use KV cache, size are
# [num_blocks, num_kv_heads, block_size, head_size]
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
ops.cpu_attn_reshape_and_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.isa,
)
if attn_metadata.use_sdpa_prefill:
assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
num_decode_tokens = attn_metadata.num_decode_tokens
self._run_sdpa_forward(
query[num_decode_tokens:num_actual_tokens],
key[num_decode_tokens:num_actual_tokens],
value[num_decode_tokens:num_actual_tokens],
output[num_decode_tokens:num_actual_tokens],
attn_metadata,
self.attn_type,
)
num_actual_tokens = num_decode_tokens
if num_actual_tokens > 0:
ops.cpu_attention_with_kv_cache(
query=query[:num_actual_tokens],
key_cache=key_cache,
value_cache=value_cache,
output=output[:num_actual_tokens], # type: ignore
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, # type: ignore
sliding_window=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=attn_metadata.scheduler_metadata,
s_aux=self.sinks,
)
return output
def _run_sdpa_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: CPUAttentionMetadata,
attn_type: str,
) -> torch.Tensor:
attn_masks = attn_metadata.sdpa_attn_masks
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.sdpa_start_loc,
)
elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.sdpa_start_loc,
self.sliding_window[0],
self.sliding_window[1],
query.dtype,
)
else:
attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1) # type: ignore
attn_metadata.sdpa_attn_masks = attn_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
causal_attn = attn_type == AttentionType.DECODER
sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
for i in range(len(attn_masks)):
mask = attn_masks[i]
start_q = sdpa_start_loc[i]
end_q = sdpa_start_loc[i + 1]
sub_out = (
torch.nn.functional.scaled_dot_product_attention(
query[None, :, start_q:end_q, :],
key[None, :, start_q:end_q, :],
value[None, :, start_q:end_q, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and mask is None,
scale=self.scale,
enable_gqa=self.num_heads > self.num_kv_heads,
)
.squeeze(0)
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = sub_out
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
sdpa_start_loc: torch.Tensor,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
seq_num = sdpa_start_loc.size(0) - 1
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
bias = torch.arange(seq_len, dtype=dtype) # type: ignore
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = (
torch.empty((1, seq_len, seq_len), dtype=bias.dtype) # type: ignore
.fill_(-torch.inf)
.triu_(diagonal=1)
)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
sdpa_start_loc: torch.Tensor,
left_window_size: int,
right_window_size: int,
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
seq_num = sdpa_start_loc.size(0) - 1
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
mask = torch.full( # type: ignore
(1, seq_len, seq_len), # type: ignore
fill_value=1,
dtype=dtype,
)
if right_window_size != -1:
mask = torch.tril(mask, diagonal=right_window_size)
if left_window_size != -1:
mask = torch.triu(mask, diagonal=-left_window_size)
mask = torch.log(mask)
attn_biases.append(mask)
return attn_biases
def _get_attn_isa(
dtype: torch.dtype, block_size: int, head_size: int | None = None
) -> str:
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
return "vec16"
supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
supports_vxe = current_platform.get_cpu_architecture() == CpuArchEnum.S390X
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
elif block_size % 32 == 0:
if supports_arm:
# support ARM NEON FMLA and BFMMLA (bf16) for block size 32
return "neon"
elif supports_vxe:
return "vxe"
else:
return "vec"
else:
return "vec16"

View File

@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Track whether upstream flash-attn is available on ROCm.
# Set during module initialization and never modified afterwards.
# This module-level flag avoids repeated import attempts and ensures
# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py).
_ROCM_FLASH_ATTN_AVAILABLE = False
if current_platform.is_cuda():
from vllm._custom_ops import reshape_and_cache_flash
# from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
# flash_attn_varlen_func,
# get_scheduler_metadata,
# )
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
elif current_platform.is_xpu():
from vllm import _custom_ops as ops
from vllm._xpu_ops import xpu_ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[assignment]
get_scheduler_metadata = xpu_ops.get_scheduler_metadata # type: ignore[assignment]
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
# Mark that upstream flash-attn is available on ROCm
_ROCM_FLASH_ATTN_AVAILABLE = True
except ImportError:
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
raise ImportError(
"ROCm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
)
# ROCm doesn't use scheduler metadata (FA3 feature), provide stub
def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
return None
# ROCm uses the C++ custom op for reshape_and_cache
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
return 3
if current_platform.is_xpu():
return 2
if current_platform.is_rocm():
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
return None
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
is_fa_version_supported,
)
device_capability = current_platform.get_device_capability()
assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2."
)
fa_version = 2
if requires_alibi and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
fa_version,
fa_version_unsupported_reason(fa_version),
)
assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None
def flash_attn_supports_fp8() -> bool:
return (
get_flash_attn_version() == 3
and current_platform.is_device_capability_family(90)
)
def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu():
return True
else:
return get_flash_attn_version() == 3
def flash_attn_supports_mla():
from vllm.platforms import current_platform
if current_platform.is_cuda():
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
is_fa_version_supported,
)
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
except (ImportError, AssertionError):
pass
return False
def is_flash_attn_varlen_func_available() -> bool:
"""Check if flash_attn_varlen_func is available.
This function determines whether the flash_attn_varlen_func imported at module
level is a working implementation or a stub.
Platform-specific sources:
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
- XPU: xpu_ops.flash_attn_varlen_func
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is
handled separately via _aiter_ops.is_aiter_found_and_supported().
Returns:
bool: True if a working flash_attn_varlen_func implementation is available.
"""
if current_platform.is_cuda() or current_platform.is_xpu():
# CUDA and XPU always have flash_attn_varlen_func available
return True
if current_platform.is_rocm():
# Use the flag set during module import to check if
# upstream flash-attn was successfully imported
return _ROCM_FLASH_ATTN_AVAILABLE
return False

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
import torch
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from .flash_attn import (
FlashAttentionBackend,
FlashAttentionImpl,
FlashAttentionMetadata,
cascade_attention,
)
logger = init_logger(__name__)
class FlashAttentionDiffKVBackend(FlashAttentionBackend):
# Default to 128 for this backend
head_size_v: int = 128
@classmethod
def set_head_size_v(cls, head_size_v: int) -> None:
cls.head_size_v = head_size_v
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_DIFFKV"
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionDiffKVImpl
# Do not modify the interface of get_kv_cache_shape,
# but consider head_size_v when returning result.
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (
num_blocks,
block_size,
num_kv_heads,
head_size + FlashAttentionDiffKVBackend.head_size_v,
)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size,
# num_kv_heads, head_size + head_size_v)
return (1, 0, 2, 3, 4)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers,
# block_size, head_size + head_size_v)
return (1, 3, 0, 2, 4)
elif cache_layout == "HND":
stride_order = (0, 2, 1, 3)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size_v]
kv_cache: shape =
[num_blocks, block_size, num_kv_heads, head_size + head_size_v]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size_v]
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for FlashAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
attn_type = self.attn_type
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
# For decoder and cross-attention, use KV cache as before
# Different head_size for K and V
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# kv_cache update for different head_size K and V
triton_reshape_and_cache_flash_diffkv(
key,
value,
kv_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype
)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
return output
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
max_num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
s_aux=self.sinks,
)
return output

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "GDN_ATTN"
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: torch.Tensor | None = None
spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,]
non_spec_query_start_loc: torch.Tensor | None = (
None # shape: [batch - num_spec_decodes + 1,]
)
spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: torch.Tensor | None = (
None # shape: [batch - num_spec_decodes,]
)
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
spec_token_indx: torch.Tensor | None = None
non_spec_token_indx: torch.Tensor | None = None
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
assert self.speculative_config.num_speculative_tokens is not None
self.num_spec: int = self.speculative_config.num_speculative_tokens
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = (
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
)
if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.non_spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: torch.Tensor | None = None,
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
query_start_loc_cpu = m.query_start_loc_cpu
context_lens_tensor = m.compute_num_computed_tokens()
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
block_table_tensor = mamba_get_block_table_tensor(
m.block_table_tensor,
m.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
spec_sequence_masks_cpu: torch.Tensor | None = None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
.sum()
.item()
== 0
):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks_cpu.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
spec_sequence_masks_cpu = None
else:
spec_sequence_masks = spec_sequence_masks_cpu.to(
query_start_loc.device, non_blocking=True
)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
)
num_spec_decode_tokens = 0
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert spec_sequence_masks_cpu is not None
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# Use CPU tensors to avoid CPU-GPU sync
non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu]
num_decodes = (non_spec_query_lens_cpu == 1).sum().item()
# Exclude zero-length padded sequences from prefill count.
num_zero_len = (non_spec_query_lens_cpu == 0).sum().item()
num_prefills = non_spec_query_lens_cpu.size(0) - num_decodes - num_zero_len
num_decode_tokens = num_decodes
num_prefill_tokens = (
non_spec_query_lens_cpu.sum().item() - num_decode_tokens
)
num_spec_decode_tokens = (
query_lens_cpu.sum().item() - num_prefill_tokens - num_decode_tokens
)
if num_prefills == 0 and num_decodes == 0:
spec_token_size = min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc_cpu[-1].item(),
)
spec_token_indx = torch.arange(
spec_token_size,
dtype=torch.int32,
device=query_start_loc.device,
)
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
# Filter by spec_sequence_masks to exclude padded sequences
spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = None
# Padded sequences are always at the back, so the first
# num_spec_decodes + 1 entries of query_start_loc already
# contain the correct cumulative token counts.
spec_query_start_loc = query_start_loc[: num_spec_decodes + 1]
non_spec_query_start_loc = None
non_spec_query_start_loc_cpu = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = block_table_tensor[
~spec_sequence_masks, 0
]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:],
)
non_spec_query_start_loc_cpu = torch.zeros(
query_lens_cpu.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
)
torch.cumsum(
query_lens_cpu[~spec_sequence_masks_cpu],
dim=0,
out=non_spec_query_start_loc_cpu[1:],
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
assert non_spec_query_start_loc_cpu is not None
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(
non_spec_query_start_loc_cpu,
device=query_start_loc.device,
)
)
else:
has_initial_state = None
# Function code counted on either presency non-spec decode or spec decode,
# but not both.
assert not (num_decodes > 0 and num_spec_decodes > 0), (
f"num_decodes: {num_decodes}, num_spec_decodes: {num_spec_decodes}"
)
# Prepare tensors for cudagraph
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
batch_size = m.num_actual_tokens
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks[:num_spec_decodes], non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert non_spec_token_indx is not None and spec_token_indx is not None
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
non_spec_token_indx, non_blocking=True
)
non_spec_token_indx = self.non_spec_token_indx[
: non_spec_token_indx.size(0)
]
self.spec_token_indx[: spec_token_indx.size(0)].copy_(
spec_token_indx, non_blocking=True
)
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size
]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True
)
non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index]
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=m.num_actual_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
spec_sequence_masks=spec_sequence_masks,
spec_token_indx=spec_token_indx,
non_spec_token_indx=non_spec_token_indx,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (
m.num_reqs <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
), (
f"GDN only supports decode-only full CUDAGraph capture. "
f"Make sure batch size ({m.num_reqs}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
f"and number of tokens ({m.num_actual_tokens}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
)
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)

View File

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class LinearAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "LINEAR_ATTN"
@staticmethod
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
return LinearAttentionMetadataBuilder
@dataclass
class LinearAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> LinearAttentionMetadata:
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
class Mamba1AttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "MAMBA1_ATTN"
@staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
return Mamba1AttentionMetadataBuilder
@dataclass
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
pass
class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
metadata_cls = Mamba1AttentionMetadata

View File

@@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from dataclasses import dataclass, replace
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.kv_cache_interface import AttentionSpec
def compute_varlen_chunk_metadata(
query_start_loc: torch.Tensor,
chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
and a physical `chunk_size`, returns three tensors on the same device:
- cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
logical-chunk lengths (each logical chunk never crosses a sequence or
physical-chunk boundary).
- last_chunk_indices: (B,) int32 index of the last logical chunk
for each sequence (=-1 for empty sequences).
- seq_idx_chunks: (nchunks,) int32 sequence index for each logical
chunk in order.
This is intentionally lightweight and CPU-side; it mirrors the metadata
produced by the V1 Mamba2 meta-data builder and is exported so tests
(and other callers) can avoid duplicating the logic.
"""
assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
device = query_start_loc.device
qsl64 = query_start_loc.to(torch.int64)
starts = qsl64[:-1].tolist()
ends = qsl64[1:].tolist()
total = int(qsl64[-1].item())
chunk_lens: list[int] = []
seq_idx_chunks: list[int] = []
last_chunk_indices: list[int] = [-1] * len(starts)
for b, (s, e) in enumerate(zip(starts, ends)):
if e <= s:
# empty sequence
continue
pos = s
while pos < e:
# split at both sequence boundaries and physical chunk boundaries
room = chunk_size - (pos % chunk_size)
take = min(room, e - pos)
chunk_lens.append(int(take))
seq_idx_chunks.append(b)
last_chunk_indices[b] = len(chunk_lens) - 1
pos += take
# Exclusive prefix sum over logical-chunk lengths
if chunk_lens:
cu_chunk_seqlens = torch.tensor(
[0] + list(itertools.accumulate(chunk_lens)),
device=device,
dtype=torch.int32,
)
# Final boundary must equal total tokens
assert int(cu_chunk_seqlens[-1].item()) == total
else:
cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
last_chunk_indices_t = (
torch.tensor(last_chunk_indices, device=device, dtype=torch.int32)
if len(starts) > 0
else torch.empty((0,), device=device, dtype=torch.int32)
)
seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
class Mamba2AttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "MAMBA2_ATTN"
@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder
@dataclass
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
prep_initial_states: bool = False
chunk_size: int = 0
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
metadata_cls = Mamba2AttentionMetadata
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models"
)
self.chunk_size: int = chunk_size
def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
**kwargs: Any,
) -> Mamba2AttentionMetadata:
common = self._compute_common_metadata(
common_attn_metadata, num_accepted_tokens=kwargs.get("num_accepted_tokens")
)
seq_idx_p = None
cu_chunk_seqlen_p = None
last_chunk_indices_p = None
prep_initial_states = False
# Compute seq_idx for prefill only
if common.num_prefills > 0:
prep_initial_states = (
torch.any(common.has_initial_states_p).item()
if common.has_initial_states_p is not None
else False
)
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
return replace(
common,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,
seq_idx_p=seq_idx_p,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
)

View File

@@ -0,0 +1,464 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from dataclasses import dataclass, replace
from typing import Any, ClassVar, TypeVar
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
M = TypeVar("M", bound="BaseMambaAttentionMetadata")
@dataclass
class BaseMambaAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_reqs: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill requests.
has_initial_states_p: torch.Tensor | None
query_start_loc_p: torch.Tensor | None
num_computed_tokens_p: torch.Tensor | None
state_indices_tensor_p: torch.Tensor | None
# The following tensors are used for decode requests and
# speculative decoding compatibility, and will be None if the batch
# has no decode requests.
state_indices_tensor_d: torch.Tensor | None
query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,]
# Number of accepted tokens for each spec sequence (for loading correct checkpoint)
# Includes the bonus token (so minimum is 1)
num_accepted_tokens: torch.Tensor | None # shape: [batch,]
# The following tensors are only used for prefix caching in all mode and
# are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
metadata_cls: type[M]
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
# Will be disabled if speculative decoding is used
supports_update_block_table: bool = True
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# Enable speculative decoding support
self.speculative_config = vllm_config.speculative_config
self.compilation_config = vllm_config.compilation_config
self.num_spec_tokens: int = vllm_config.num_speculative_tokens
self.use_spec_decode = self.num_spec_tokens > 0
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.mamba_cache_mode == "all":
max_num_blocks = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size,
)
# Speculative decoding not supported with prefix caching,
# so keep shape consistent with prefill buffer
# TODO: reduce this size as needed for decode-only cudagraph capture
self.state_indices_tensor_d = torch.empty(
(
self.decode_cudagraph_max_bs,
max_num_blocks,
),
dtype=torch.int32,
device=device,
)
self.block_idx_last_scheduled_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_computed_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
else:
self.state_indices_tensor_d = torch.empty(
(self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens),
dtype=torch.int32,
device=device,
)
# For speculative decoding, we need to store the following buffers
# for CUDA graph capture during decode
if self.num_spec_tokens > 0:
self.decode_num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self._init_reorder_batch_threshold(1, self.use_spec_decode)
if self.use_spec_decode:
self.supports_update_block_table = False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (
m.max_query_len <= 1 + self.num_spec_tokens
and m.num_reqs <= self.decode_cudagraph_max_bs
), (
"Mamba only supports decode-only full CUDAGraph capture. "
"Make sure all cudagraph capture sizes <= max_num_seq."
)
assert m.max_query_len == 1 + self.num_spec_tokens # decode-only
num_accepted_tokens = None
if self.num_spec_tokens > 0:
num_accepted_tokens = torch.diff(m.query_start_loc)
return self.build(0, m, num_accepted_tokens=num_accepted_tokens)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
*,
num_accepted_tokens: torch.Tensor | None = None,
**kwargs: Any,
) -> M:
"""
Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata.
"""
return self._compute_common_metadata(
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
)
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,
mamba_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Block index of the last computed token
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
# which is <= block index for the first scheduled token
block_idx_first_scheduled_token = (
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
)
# which is <= block index of the last scheduled token
block_idx_last_scheduled_token = (
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = torch.clamp(
block_idx_last_computed_token, min=0
)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = torch.clamp(
block_idx_last_scheduled_token, min=0
)
return (
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
)
def _compute_common_metadata(
self,
common_attn_metadata: CommonAttentionMetadata,
*,
num_accepted_tokens: torch.Tensor | None = None,
) -> M:
"""
Compute metadata common to both Mamba1 and Mamba2.
"""
num_reqs = common_attn_metadata.num_reqs
# Treat multi-token queries as decode requests when
# speculative decoding is enabled. Otherwise, use the
# default decode threshold to prevent misclassification
# of prefill queries as decode requests.
decode_threshold = (
self.reorder_batch_threshold if num_accepted_tokens is not None else 1
)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=decode_threshold
)
)
# Need flags to indicate if there are initial states
has_initial_states_p = None
query_start_loc_p = None
query_start_loc_d = None
num_computed_tokens = None
num_computed_tokens_p = None
# for prefix caching
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
block_idx_last_computed_token = None
block_idx_last_scheduled_token = None
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.mamba_cache_mode == "all":
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
if state_indices_tensor.dim() == 1:
state_indices_tensor = state_indices_tensor.unsqueeze(-1)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
if self.vllm_config.cache_config.mamba_cache_mode != "all":
state_indices_tensor_d = state_indices_tensor_d[
:, : 1 + self.num_spec_tokens
]
state_indices_tensor_p = state_indices_tensor_p[:, 0]
if num_decodes > 0 and self.use_spec_decode:
assert num_accepted_tokens is not None
query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
num_accepted_tokens = num_accepted_tokens[:num_decodes]
if num_prefills > 0:
if num_computed_tokens is None:
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_p = (
num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(
query_start_loc_p_cpu,
device=common_attn_metadata.query_start_loc.device,
)
)
if self.vllm_config.cache_config.mamba_cache_mode == "all":
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
metadata = self.metadata_cls(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
num_accepted_tokens=num_accepted_tokens,
query_start_loc_d=query_start_loc_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
seq_lens=common_attn_metadata.seq_lens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return self._update_metadata_for_cudagraph_capture(metadata)
def _update_metadata_for_cudagraph_capture(
self,
metadata: M,
) -> M:
"""
Update the metadata for cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
state_indices_tensor_d = metadata.state_indices_tensor_d
query_start_loc_d = metadata.query_start_loc_d
num_accepted_tokens = metadata.num_accepted_tokens
block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
block_idx_last_computed_token = metadata.block_idx_last_computed_token
if (
metadata.num_prefills == 0
and metadata.num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
padded_bs = metadata.num_reqs
self.state_indices_tensor_d[: metadata.num_decodes].copy_(
state_indices_tensor_d, non_blocking=True
)
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
if self.use_spec_decode:
assert query_start_loc_d is not None
assert num_accepted_tokens is not None
query_start_loc_d = query_start_loc_d[: padded_bs + 1]
self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.decode_num_accepted_tokens[:padded_bs]
num_accepted_tokens[metadata.num_decodes :] = (
1 # pad with 1st slot index
)
if self.vllm_config.cache_config.mamba_cache_mode == "all":
assert block_idx_last_scheduled_token is not None
assert block_idx_last_computed_token is not None
self.block_idx_last_scheduled_token[: metadata.num_decodes].copy_(
block_idx_last_scheduled_token[: metadata.num_decodes],
non_blocking=True,
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
: metadata.num_decode_tokens
]
self.block_idx_last_computed_token[: metadata.num_decodes].copy_(
block_idx_last_computed_token[: metadata.num_decodes],
non_blocking=True,
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
: metadata.num_decode_tokens
]
return replace(
metadata,
state_indices_tensor_d=state_indices_tensor_d,
query_start_loc_d=query_start_loc_d,
num_accepted_tokens=num_accepted_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_last_computed_token=block_idx_last_computed_token,
)
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
state_indices_tensor = mamba_get_block_table_tensor(
blk_table,
metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
if state_indices_tensor.dim() == 1:
state_indices_tensor = state_indices_tensor.unsqueeze(-1)
assert (
metadata.num_prefills + metadata.num_decodes
== state_indices_tensor.shape[0]
), (
"Mismatch in number of requests when updating block table."
f" Expected {metadata.num_prefills + metadata.num_decodes}, "
f"got {state_indices_tensor.shape[0]}."
)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[metadata.num_decodes, metadata.num_prefills],
dim=0,
)
if self.vllm_config.cache_config.mamba_cache_mode != "all":
state_indices_tensor_d = state_indices_tensor_d[
:, : 1 + self.num_spec_tokens
]
state_indices_tensor_p = state_indices_tensor_p[:, 0]
new_metadata = replace(
metadata,
state_indices_tensor_d=state_indices_tensor_d,
state_indices_tensor_p=state_indices_tensor_p,
)
return self._update_metadata_for_cudagraph_capture(new_metadata)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLABackend, AiterMLAImpl
class AiterTritonMLABackend(AiterMLABackend):
@staticmethod
def get_name() -> str:
return "AITER_TRITON_MLA"
@staticmethod
def get_impl_cls() -> type["AiterTritonMLAImpl"]:
return AiterTritonMLAImpl
class AiterTritonMLAImpl(AiterMLAImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
from aiter.ops.triton.mha import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
result = self.flash_attn_varlen_func( # type: ignore[call-arg]
q,
k,
v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
# Transpose the LSE if Triton MHA is used:
# (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
if type(result) is tuple and return_softmax_lse:
output, lse = result
lse = lse.T.contiguous()
return (output, lse)
return result

View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import ClassVar
import torch
import vllm._custom_ops as ops
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [128]
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
@staticmethod
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl
@staticmethod
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 10
class SM100Workspace:
def __init__(self, initial_workspace_size):
self._workspace_buf = torch.empty(
initial_workspace_size, device="cuda", dtype=torch.uint8
)
self._block_size = 128 # Forced to 128
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
# (assumes all devices are similar)
self._sm_count = num_compute_units(0)
def get_buf(self):
return self._workspace_buf
def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
batch_size = attn_metadata.num_reqs
max_seq_len = attn_metadata.max_query_len
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seq_len * self._block_size,
batch_size,
self._sm_count,
num_kv_splits=num_kv_splits,
)
if self._workspace_buf.shape[0] < workspace_size:
self._workspace_buf.resize_(workspace_size)
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
MAX_HEADS = 128
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
q_pad_num_heads=MAX_HEADS,
**mla_args,
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl"
)
# TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
# issues. In case the code hangs, use:
# FORCE_NUM_KV_SPLITS=1
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
if force_num_kv_splits:
logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
self._num_kv_splits = int(force_num_kv_splits)
else:
self._num_kv_splits = -1 # => Auto-detect
# Share workspace buffer across all executions
self._workspace = g_sm100_workspace
def _sm100_cutlass_mla_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor,
page_table: torch.Tensor,
workspace: torch.Tensor,
sm_scale: float,
num_kv_splits: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
assert kv_c_and_k_pe_cache.ndim == 3, (
"kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
kv_c_and_k_pe_cache.ndim
)
)
B_q, H, D_q_nope = q_nope.shape
B_q_2, H_2, D_q_pe = q_pe.shape
assert (B_q == B_q_2) and (H == H_2)
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
D_latent = 512
D_rope = 64
assert D_q_nope == D_latent
assert D_q_pe == D_rope
assert D_ckv == D_latent + D_rope
MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
)
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
assert seq_lens.dtype == torch.int32, (
f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
)
assert page_table.dtype == torch.int32, (
f"page_table.dtype needs to be int32 but got {page_table.dtype}."
)
dtype = (
torch.bfloat16
if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype
)
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
lse = (
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode
else torch.Tensor()
)
ops.sm100_cutlass_mla_decode(
out,
lse,
q_nope,
q_pe,
kv_c_and_k_pe_cache,
seq_lens,
page_table,
workspace,
sm_scale,
num_kv_splits,
)
if H < MAX_HEADS:
# Extract the subsets of the outputs
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
out = out[:, :H]
return out, lse
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# Adjust workspace size (if necessary)
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
# Run MLA
o, lse = self._sm100_cutlass_mla_decode(
q_nope,
q_pe,
kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table,
self._workspace.get_buf(),
self.scale,
self._num_kv_splits,
)
return o, (lse if self.need_to_return_lse_for_decode else None)

View File

@@ -0,0 +1,361 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla,
get_flash_attn_version,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
get_scheduler_metadata,
)
logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
@staticmethod
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
return FlashAttnMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
return FlashAttnMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 9
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if not flash_attn_supports_mla():
return "FlashAttention MLA not supported on this device"
return None
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
query_start_loc: torch.Tensor
max_query_len: int
max_seq_len: int
scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
pass
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
super().__init__(
kv_cache_spec,
layer_names,
vllm_config,
device,
FlashAttnMLAMetadata,
supports_dcp_with_varlen=(interleave_size == 1),
)
self.max_num_splits = 0 # No upper bound on the number of splits.
self.fa_aot_schedule = get_flash_attn_version() == 3
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.fa_aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = (
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
if vllm_is_batch_invariant():
self.max_num_splits = 1
def _schedule_decode(
self,
num_reqs,
cu_query_lens,
max_query_len,
seqlens,
max_seq_len,
causal,
max_num_splits,
):
if self.fa_aot_schedule:
return get_scheduler_metadata(
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads * self.dcp_world_size,
num_heads_kv=1,
headdim=self.mla_dims.qk_rope_head_dim,
cache_seqlens=seqlens,
qkv_dtype=self.kv_cache_spec.dtype,
headdim_v=self.mla_dims.kv_lora_rank,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
num_splits=max_num_splits,
)
return None
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
if (
self.use_full_cuda_graph
and self.max_cudagraph_size is not None
and num_decode_tokens <= self.max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if vllm_is_batch_invariant():
max_num_splits = 1
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_device.shape[0],
cu_query_lens=query_start_loc_device,
max_query_len=max_query_len,
seqlens=seq_lens_device,
max_seq_len=max_seq_len,
causal=True,
max_num_splits=max_num_splits,
)
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
# Ensure the persistent buffer is large enough
assert n <= self.scheduler_metadata.shape[0], (
f"Scheduler metadata size {n} exceeds buffer size "
f"{self.scheduler_metadata.shape[0]}"
)
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
return metadata
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashAttnMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttnMLAImpl"
)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
)
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
# to prevent invalid grid configuration during graph capture.
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
attn_out = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
q_v=q_nope,
max_seqlen_q=max_seqlen_q,
cu_seqlens_q=attn_metadata.decode.query_start_loc,
max_seqlen_k=attn_metadata.decode.max_seq_len,
seqused_k=attn_metadata.decode.seq_lens,
block_table=attn_metadata.decode.block_table,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=self.need_to_return_lse_for_decode,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
)
if self.need_to_return_lse_for_decode:
o, lse = attn_out
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
else:
o = attn_out
return o, None

View File

@@ -0,0 +1,202 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl
@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 10
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
# FlashInfer MLA kernel requires qk_nope_head_dim == 128
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None:
hf_text_config = vllm_config.model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if qk_nope_head_dim != 128:
return (
f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
f"but got {qk_nope_head_dim}"
)
return None
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return "HND"
g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device="cuda",
)
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashInferMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferMLAImpl"
)
self._workspace_buffer = g_fi_workspace
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if isinstance(q, tuple):
q_nope, q_pe = q
q = torch.cat([q_nope, q_pe], dim=-1)
# trtllm API requires extra dimension q_len_per_request for MTP
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
logger.warning_once(
"""FlashInferMLAImpl got a query of uneven length.
This usually indicates an issue in batch reordering
or incorrect setup in dummy_run."""
)
q = q.unsqueeze(1)
else:
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])
# TODO: Return LSE pending support from Flashinfer API:
# https://github.com/flashinfer-ai/flashinfer/pull/1566
return o, None

View File

@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FlashInfer MLA Sparse Attention Backend.
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
for models like DeepSeek-V3.2 that use index-based sparse attention.
For sparse MLA:
- block_tables shape changes from [batch_size, max_num_blocks] (dense)
to [batch_size, q_len_per_request, sparse_mla_top_k] (sparse)
- The sparse indices represent physical cache slot positions to attend to
- sparse_mla_top_k parameter must be set to the topk value
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar
import numpy as np
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.sparse_utils import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLASparseBackend(AttentionBackend):
"""FlashInfer MLA backend with sparse attention support.
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
for models like DeepSeek-V3.2 that use index-based sparse attention.
"""
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA_SPARSE"
@staticmethod
def get_impl_cls() -> type["FlashInferMLASparseImpl"]:
return FlashInferMLASparseImpl
@staticmethod
def get_builder_cls() -> type["FlashInferMLASparseMetadataBuilder"]:
return FlashInferMLASparseMetadataBuilder
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@classmethod
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
# FlashInfer sparse MLA targets Blackwell (SM 10.x)
return capability.major == 10
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
# FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None:
hf_text_config = vllm_config.model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if qk_nope_head_dim != 128:
return (
f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
f"but got {qk_nope_head_dim}"
)
# Check for index_topk which indicates sparse model
if not hasattr(hf_text_config, "index_topk"):
return "FlashInfer MLA Sparse requires model with index_topk config"
return None
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return "HND"
@dataclass
class FlashInferMLASparseMetadata(AttentionMetadata):
"""Attention metadata for FlashInfer MLA Sparse backend."""
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int
# Query start locations
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
req_id_per_token: torch.Tensor
# Sequence lengths for all requests (context + query)
seq_lens: torch.Tensor
# Sparse-specific
block_size: int = 64
topk_tokens: int = 2048
class FlashInferMLASparseMetadataBuilder(
AttentionMetadataBuilder[FlashInferMLASparseMetadata]
):
"""Builder for FlashInfer MLA Sparse attention metadata."""
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
) -> None:
self.vllm_config = vllm_config
self.layer_names = layer_names
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.device = device
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.req_id_per_token_buffer = torch.empty(
(vllm_config.scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=device,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashInferMLASparseMetadata:
cm = common_attn_metadata
num_tokens = cm.num_actual_tokens
# Build req_id_per_token mapping
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token_tensor = self.req_id_per_token_buffer[:num_tokens]
return FlashInferMLASparseMetadata(
num_reqs=cm.num_reqs,
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor,
req_id_per_token=req_id_per_token_tensor,
seq_lens=cm.seq_lens,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
)
# Global workspace buffer (lazily initialized)
_fi_sparse_workspace: torch.Tensor | None = None
def _get_workspace_buffer(device: torch.device) -> torch.Tensor:
global _fi_sparse_workspace
if _fi_sparse_workspace is None:
_fi_sparse_workspace = torch.zeros(
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=device,
)
return _fi_sparse_workspace
class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata]):
"""FlashInfer MLA Sparse implementation.
Uses the TRT-LLM MLA kernel with sparse_mla_top_k parameter for
sparse attention computation.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None,
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashInferMLASparseImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferMLASparseImpl"
)
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
# MLA-specific dimensions
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.qk_nope_head_dim: int = mla_args["qk_nope_head_dim"]
self.qk_rope_head_dim: int = mla_args["qk_rope_head_dim"]
assert indexer is not None, "Indexer required for sparse MLA"
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self._workspace_buffer: torch.Tensor | None = None
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashInferMLASparseMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if isinstance(q, tuple):
q = torch.cat(q, dim=-1)
num_actual_toks = q.shape[0]
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
topk_indices_physical, seq_lens = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token[:num_actual_toks],
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
return_valid_counts=True,
)
if self._workspace_buffer is None:
self._workspace_buffer = _get_workspace_buffer(q.device)
if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
o = trtllm_batch_decode_with_kv_cache_mla(
query=q.unsqueeze(1),
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=topk_indices_physical.unsqueeze(1),
seq_lens=seq_lens,
max_seq_len=attn_metadata.topk_tokens,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
sparse_mla_top_k=attn_metadata.topk_tokens,
)
return o.view(-1, o.shape[-2], o.shape[-1]), None

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
)
from vllm.v1.attention.ops.flashmla import (
FlashMLASchedMeta,
flash_mla_with_kvcache,
flash_mla_with_kvcache_fp8,
get_mla_metadata,
get_mla_metadata_dense_fp8,
is_flashmla_dense_supported,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA"
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major in [9, 10]
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if use_sparse:
from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported
return is_flashmla_sparse_supported()[1]
else:
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
return is_flashmla_dense_supported()[1]
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
scheduler_metadata: FlashMLASchedMeta
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
# ^ TODO(matt): tune this
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
)
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
num_sms = num_compute_units(self.device.index)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
(num_sms, 8),
device=self.device,
dtype=torch.int32,
)
self.cg_buf_num_splits = torch.empty(
(vllm_config.scheduler_config.max_num_seqs + 1),
device=self.device,
dtype=torch.int32,
)
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> FlashMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# we use the max but all should be the same due to uniform length requirement
max_query_len = query_lens_cpu.max().item()
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
scheduler_metadata, _ = get_mla_metadata(
seq_lens_device,
num_q_tokens_per_head_k,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)
if self.is_fp8_kvcache:
tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
seq_lens_device,
num_q_tokens_per_head_k,
1, # MQA for the decode path
)
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
scheduler_metadata=scheduler_metadata,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
is_supported, reason = is_flashmla_dense_supported()
assert is_supported, reason
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl"
)
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# TODO: (zyongye) decode function for mla here
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
# mypy assertion: q is now always a tensor
assert isinstance(q, torch.Tensor)
num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata
if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
device = q.device
dtype = torch.int32
B = q.shape[0]
# block_table shape: [batch_size, max_num_blocks_per_seq]
# The number of blocks per sequence is in the second dimension
topk = attn_metadata.decode.block_table.shape[-1]
B_TOPK = 64
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
end_block_idx = topk // B_TOPK
# Single partition => num_sm_parts = 1
# TileSchedulerMetaDataSize = 8, layout:
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
# begin_n_split_idx, _, _, _]
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
tile_scheduler_metadata[0, 0] = 0 # begin_idx
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
tile_scheduler_metadata[0, 3] = end_block_idx
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
# fields [5..7] stay 0
# Non-split path ignores num_splits, but the API requires it:
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits
if self.kv_cache_dtype.startswith("fp8"):
o, lse = flash_mla_with_kvcache_fp8(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
num_splits=scheduler_metadata.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
else:
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
)
o = reshape_attn_output_for_spec_decode(o)
return o, lse

View File

@@ -0,0 +1,847 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.sparse_utils import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.attention.ops.flashmla import (
FlashMLASchedMeta,
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
# For FP8 sparse attention we have two impelementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode this is
# done by treating all tokens as single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using
# the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the BF16
# prefill kernel requires padding the numer of heads to 128 while the decode does not
# so when the per ranke head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed
# batch mode (#2).
MIN_HEADS_FOR_BF16_PREFILL = 32
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8_ds_mla",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE"
@staticmethod
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
return FlashMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLASparseImpl"]:
return FlashMLASparseImpl
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@classmethod
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major in [9, 10]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if cache_dtype_str == "fp8_ds_mla":
# custom storage fromat is 656 bytes
# see FlashMLA readme.md for details
return (num_blocks, block_size, 656)
else:
return (num_blocks, block_size, head_size)
@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
req_id_per_token: torch.Tensor
block_size: int = 64
topk_tokens: int = 2048
@dataclass
class FP8KernelMetadata:
scheduler_metadata: FlashMLASchedMeta
dummy_block_table: torch.Tensor
cache_lens: torch.Tensor
@dataclass
class FP8SeparatePrefillDecode:
@dataclass
class Decode:
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
decode_query_len: int # needed for reshape in spec decode
@dataclass
class Prefill:
# Sequence lengths (context + query) for prefill requests
# Shape: [num_prefill_reqs]
seq_lens: torch.Tensor
# Request ID for each token: -1 for decode tokens, request index
# (0, 1, 2, ...) for prefill tokens.
# Shape: [num_actual_tokens]
request_ids: torch.Tensor
# Workspace start offsets for all prefill requests
# Shape: [num_prefill_reqs], adjusted in-place per chunk to be
# 0-indexed within each chunk. Used to map prefill tokens to workspace
# offsets in convert_logical_index_to_physical_index
workspace_starts: torch.Tensor
@dataclass
class Chunk:
"""Metadata for a chunk of prefill requests.
Prefill requests may be chunked to fit within the fixed workspace size.
"""
seq_lens: torch.Tensor
tokens_slice: slice
block_table: torch.Tensor
req_start_idx: int
workspace_starts: torch.Tensor
chunk_tot_seqlen: int
chunks: list[Chunk]
num_prefills: int = 0
num_decodes: int = 0
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
decode: Decode | None = None
prefill: Prefill | None = None
fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
fp8_use_mixed_batch: bool = False
def get_prefill_workspace_size(max_model_len: int):
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
# May be tuned later.
# Memory usage: 5 * max_model_len * 576 * 2 bytes
# Example: DeepSeek-V3.2 with max_model_len=163840 ->
# 5 * 163840 * 576 * 2 = ~900 MB
# This fits nicely below the typical MoE workspace size of >2GB so this is "free"
return max_model_len * 5
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
) -> None:
self.vllm_config = vllm_config
self.layer_names = layer_names
cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
sm_count = num_compute_units(device.index)
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
# FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
self.fp8_decode_padded_heads = (
FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
# Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
self.topk_tokens_tensor = torch.full(
(max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
)
# Shape: [max_num_seqs], all elements = max_model_len
self.max_model_len_tensor = torch.full(
(max_num_seqs,),
self.model_config.max_model_len,
device=device,
dtype=torch.int32,
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty(
(max_num_seqs, 1), dtype=torch.int32, device=self.device
)
# Equation taken from FlashMLA/csrc/api/sparse_decode.h
# For sparse FP8 decode, the formula depends on architecture:
# - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
# - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
# - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
# For max buffer size, use s_q = 1 (the case that produces largest output)
# Use padded head count since that's what will be passed to the kernel
h_q = self.fp8_decode_padded_heads
if current_platform.is_device_capability_family(100):
# SM100 head64 or head64x2 uses full SM count
max_num_sm_parts = sm_count
else:
# SM90 uses h_q/64 divisor
max_num_sm_parts = sm_count // max(1, h_q // 64)
self.tile_scheduler_metadata_buffer = torch.empty(
# TileSchedulerMetaDataSize = 8
# see: FlashMLA/csrc/params.h
(max_num_sm_parts, 8),
dtype=torch.int32,
device=device,
)
# Sized for per-request batching (num_decodes + 1)
self.num_splits_buffer = torch.empty(
(max_num_seqs + 1,),
dtype=torch.int32,
device=device,
)
self.req_id_per_token_buffer = torch.empty(
(vllm_config.scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=device,
)
def _build_fp8_mixed_decode_prefill(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
"""Build FP8 metadata treating all tokens as one mixed batch.
This matches main branch's approach and avoids the BF16 prefill kernel
which has head padding overhead when num_heads is small (high TP case).
"""
num_tokens = common_attn_metadata.num_actual_tokens
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
# Build metadata for all tokens as a single batch
scheduler_metadata, _ = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * padded_heads,
topk=self.topk_tokens,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=scheduler_metadata,
cache_lens=self.max_model_len_tensor[:1],
dummy_block_table=self.dummy_block_table[:1],
)
return fp8_metadata
def _build_fp8_separate_prefill_decode(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
num_tokens = common_attn_metadata.num_actual_tokens
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
fp8_metadata = FP8Meta(
num_decodes=num_decodes,
num_prefills=num_prefills,
num_decode_tokens=num_decode_tokens,
num_prefill_tokens=num_prefill_tokens,
)
# Extract prefill sequence lengths (context + query, not just query)
# Decode requests come first in the batch, prefill requests follow
prefill_seq_lens = None
prefill_request_id = None
prefill_workspace_starts = None
prefill_chunks = None
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
prefill_seq_lens = seq_lens[num_decodes:]
# Build prefill_request_id: -1 for decode, request index for
# prefill. This enables a single
# convert_logical_index_to_physical_index call for all tokens
prefill_request_id = torch.full(
(num_tokens,), -1, dtype=torch.int32, device=self.device
)
# Map prefill tokens to their request IDs (0, 1, 2, ...)
for req_idx in range(num_prefills):
# Get query token range for this prefill request
global_req_idx = num_decodes + req_idx
req_query_start = query_start_loc_cpu[global_req_idx]
req_query_end = query_start_loc_cpu[global_req_idx + 1]
prefill_request_id[req_query_start:req_query_end] = req_idx
# will be adjusted by chunk loop
prefill_workspace_starts_cpu = torch.zeros(
num_prefills, dtype=torch.int32, pin_memory=True
)
prefill_workspace_starts_cpu[1:] = torch.cumsum(
prefill_seq_lens_cpu[:-1], dim=0
)
# populated by non-blocking copy after prefill_workspace_starts_cpu is
# updated by each chunk
prefill_workspace_starts = torch.empty(
num_prefills, dtype=torch.int32, device=self.device
)
# Chunk prefill requests to fit within workspace size
max_prefill_buffer_size = get_prefill_workspace_size(
self.vllm_config.model_config.max_model_len
)
chunk_bounds = split_prefill_chunks(
prefill_seq_lens_cpu, max_prefill_buffer_size
)
prefill_chunks = []
for chunk_start, chunk_end in chunk_bounds:
# Adjust workspace_starts in-place per chunk to be
# 0-indexed within each chunk
# Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
# Initial: workspace_starts=[0,10,25,45]
# After: workspace_starts=[0,10,0,20]
# (chunk 0 starts at 0, chunk 1 starts at 0)
offset = prefill_workspace_starts_cpu[chunk_start].item()
prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset
chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
tokens_slice = slice(token_start, token_end)
# Create chunk view of gpu tensor
chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
chunk_block_table = common_attn_metadata.block_table_tensor[
num_decodes + chunk_start : num_decodes + chunk_end
]
prefill_chunks.append(
FP8Meta.Prefill.Chunk(
seq_lens=chunk_seq_lens,
tokens_slice=tokens_slice,
block_table=chunk_block_table,
req_start_idx=chunk_start,
workspace_starts=chunk_workspace_starts,
chunk_tot_seqlen=chunk_tot_seqlen,
)
)
prefill_workspace_starts.copy_(
prefill_workspace_starts_cpu, non_blocking=True
)
fp8_metadata.prefill = FP8Meta.Prefill(
seq_lens=prefill_seq_lens,
request_ids=prefill_request_id,
workspace_starts=prefill_workspace_starts,
chunks=prefill_chunks,
)
if num_decodes > 0:
# Compute decode_query_len for spec decode (uniform due to require_uniform)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
scheduler_metadata, _ = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * padded_heads,
topk=self.topk_tokens,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=scheduler_metadata,
dummy_block_table=self.dummy_block_table[:num_decodes],
cache_lens=self.max_model_len_tensor[:num_decodes],
)
fp8_metadata.decode = FP8Meta.Decode(
kernel_metadata=kernel_meta,
decode_query_len=decode_query_len,
)
return fp8_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
cm = common_attn_metadata
num_tokens = cm.num_actual_tokens
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata: (
FlashMLASparseMetadata.FP8SeparatePrefillDecode
| FlashMLASparseMetadata.FP8KernelMetadata
| None
) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
if self.use_fp8_kv_cache:
if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
else:
fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)
metadata = FlashMLASparseMetadata(
num_reqs=cm.num_reqs,
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
fp8_use_mixed_batch=fp8_use_mixed_batch,
)
return metadata
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
@staticmethod
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
# FP8 decode kernel only supports h_q = 64 or 128
# Compute padded head count for decode
return 64 if num_heads <= 64 else 128
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None,
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
# Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
self.prefill_padding = (
128 if current_platform.is_device_capability_family(100) else 64
)
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()
assert vllm_config is not None and vllm_config.model_config is not None
prefill_workspace_size = get_prefill_workspace_size(
vllm_config.model_config.max_model_len
)
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
(self.prefill_bf16_workspace,) = (
current_workspace_manager().get_simultaneous(
(self.prefill_workspace_shape, torch.bfloat16)
)
)
def _forward_bf16_kv(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
)
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
def _forward_fp8_kv_separate_prefill_decode(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
num_decodes = fp8_metadata.num_decodes
prefill_request_ids = None
prefill_workspace_starts = None
has_prefill_workspace = False
if fp8_metadata.prefill is not None:
prefill_request_ids = fp8_metadata.prefill.request_ids
prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
has_prefill_workspace = True
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
# For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
# For BF16 cache: always use global cache slots (no workspace)
# prefill_workspace_starts has been adjusted in-place per chunk so
# prefill indices automatically come out chunk-local
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
HAS_PREFILL_WORKSPACE=has_prefill_workspace,
prefill_workspace_request_ids=prefill_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim)
# -> (num_decodes, seq_len, num_heads, head_dim)
q = reshape_query_for_spec_decode(q, num_decodes)
seq_len = q.shape[1]
# Reshape topk_indices: (num_decode_tokens, topk)
# -> (num_decodes, seq_len, topk)
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
assert fp8_metadata.decode is not None
attn_out, _ = self._fp8_flash_mla_kernel(
q=q,
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
topk_indices=topk_indices,
kernel_metadata=fp8_metadata.decode.kernel_metadata,
)
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
# -> (num_decode_tokens, num_heads, head_dim_v)
return reshape_attn_output_for_spec_decode(attn_out)
num_decode_tokens = fp8_metadata.num_decode_tokens
num_prefill_tokens = fp8_metadata.num_prefill_tokens
# Pure decode: direct call without allocation
if num_decode_tokens > 0 and num_prefill_tokens == 0:
assert fp8_metadata.decode is not None
attn_out = _fp8_decode(q, topk_indices)
else:
# Mixed or pure prefill: allocate output tensor
attn_out = q.new_empty(
(attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
dtype=q.dtype,
device=q.device,
)
if num_decode_tokens > 0:
attn_out[:num_decode_tokens] = _fp8_decode(
q[:num_decode_tokens], topk_indices[:num_decode_tokens]
)
assert fp8_metadata.prefill is not None
for chunk in fp8_metadata.prefill.chunks:
chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
ops.cp_gather_and_upconvert_fp8_kv_cache(
kv_c_and_k_pe_cache,
chunk_workspace,
chunk.block_table,
chunk.seq_lens,
chunk.workspace_starts,
len(chunk.block_table),
)
chunk_q = q[chunk.tokens_slice]
chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]
attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
chunk_q,
chunk_workspace,
chunk_topk_indices_workspace,
)
return attn_out
def _forward_fp8_kv_mixed_batch(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
"""Mixed batch FP8 forward path that treats all tokens as one batch.
This is equivalent to main branch's approach and avoids the BF16
prefill kernel which has head padding overhead when num_heads is small.
Used when use_mixed_batch is True.
"""
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
)
assert attn_metadata.fp8_extra_metadata is not None
assert isinstance(
attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
)
fp8_metadata = attn_metadata.fp8_extra_metadata
_attn_out, _ = self._fp8_flash_mla_kernel(
q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk)
kernel_metadata=fp8_metadata,
)
# Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
return _attn_out.squeeze(0)
def _fp8_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
# q shape: (batch, seq_len, num_heads, head_dim)
actual_num_heads = q.size(2)
padded_num_heads = self.fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128)
if actual_num_heads < padded_num_heads:
logger.warning_once(
f"Padding num_heads from {actual_num_heads} to "
f"{padded_num_heads} for FP8 sparse decode kernel"
)
q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
q_padded[:, :, :actual_num_heads, :] = q
q = q_padded
out, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=kernel_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=kernel_metadata.cache_lens,
tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
is_fp8_kvcache=True,
indices=topk_indices,
softmax_scale=self.softmax_scale,
)
# Slice output back to actual head count if we padded
if actual_num_heads < padded_num_heads:
out = out[:, :, :actual_num_heads, :]
return out, lse
def _bf16_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1, 1, kv_c_and_k_pe_cache.shape[-1]
)
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_fwd(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
)[0]
output = output[:, : self.num_heads, :]
return output
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
# Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
q = torch.cat(q, dim=-1)
num_actual_toks = q.shape[0]
# Get topk indices
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
if not use_fp8_cache:
attn_out = self._forward_bf16_kv(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
elif attn_metadata.fp8_use_mixed_batch:
attn_out = self._forward_fp8_kv_mixed_batch(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
else:
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
return attn_out, None

View File

@@ -0,0 +1,386 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "DEEPSEEK_V32_INDEXER"
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 128]
@staticmethod
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
return DeepseekV32IndexerMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
assert num_kv_heads == 1
return (num_blocks, block_size, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
return (0, 1, 2, 3)
return (0, 1, 2)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
@dataclass
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
@dataclass
class DeepseekV32IndexerMetadata:
# FIXME (zyongye)
# hacky way to access the data now, need to be in chunked meta
seq_lens: torch.Tensor
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# The dimension of the attention heads
head_dim: int
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
decode: DeepSeekV32IndexerDecodeMetadata | None = None
prefill: DeepseekV32IndexerPrefillMetadata | None = None
# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
start_seq_loc: 1D long tensor [B+1], cumulative counts of
selected tokens per batch.
Example: [0, 2, 4, 7] ->
batch sizes (selected) [2, 2, 3], N=7 tokens total.
seq_len_per_batch: 1D long tensor [B],
full sequence length (KV length) of each batch.
Example: [5, 9, 4].
Returns:
start_tensor: 1D long tensor [N], start offset in the
concatenated KV cache for each token's batch.
end_location: 1D long tensor [N],
**exclusive** end = start + token's local position.
(So the attended KV slice is kv[start:end].)
Assumes each batch contributes its full `seq_len_per_batch[i]`
keys to the KV cache, andthe selected tokens within a batch
are the **last** `counts[i]` positions of that sequence.
"""
q = start_seq_loc.to(dtype=torch.long)
L = seq_len_per_batch.to(dtype=torch.long)
assert q.dim() == 1 and L.dim() == 1
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
# Selected tokens per batch and totals
counts = q[1:] - q[:-1] # [B]
N = int(q[-1].item()) # total selected tokens
B = L.numel()
if N == 0:
return (
torch.empty(0, dtype=torch.long, device=device),
torch.empty(0, dtype=torch.long, device=device),
)
# KV start offsets per batch in the concatenated KV cache
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
# For each selected token, which batch does it belong to?
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
# Map batch KV start to each token
start_tensor = kv_starts_per_batch[batch_id] # [N]
# End-align local positions inside each batch:
# local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
L_expand = torch.repeat_interleave(L, counts) # [N]
m_expand = torch.repeat_interleave(counts, counts) # [N]
# position within the selected block: 1..counts[b]
pos_within = (
torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
)
local_pos = L_expand - m_expand + pos_within # [N], 1-based
end_location = start_tensor + local_pos # exclusive end
return start_tensor.int().to(device), end_location.int().to(device)
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
# within the flashmla_sparse workspace.
# For DeepSeek-V3.2, the max_model_len is 163840.
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
return max_model_len * 40
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
scheduler_config = self.vllm_config.scheduler_config
# NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
self.num_speculative_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config
else 0
)
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
)
# See: DeepGMM/csrc/apis/attention.hpp
self.scheduler_metadata_buffer = torch.empty(
(self.num_sms + 1, 2), dtype=torch.int32, device=self.device
)
def build_one_prefill_chunk(
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
):
prefill_query_start_loc = (
query_start_loc_cpu[reqs_start : reqs_end + 1]
- query_start_loc_cpu[reqs_start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
).to(self.device)
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
]
)
.to(torch.int32)
.to(self.device)
)
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start,
reqs_end,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor,
)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,
)
decode_metadata = None
if num_decodes > 0:
torch.diff(
common_attn_metadata.query_start_loc[: num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes],
)
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
# Decide which top-k kernel to use based on batch size and sequence length
batch_size = num_decodes
_is_large_context = common_attn_metadata.max_seq_len > 8192
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
use_large_context_topk = batch_size <= 128 and _is_large_context
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
else:
offsets = None
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
)
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill=prefill_metadata,
decode=decode_metadata,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata

View File

@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
@staticmethod
def get_impl_cls() -> type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: torch.Tensor | None = None
# The page indices of the paged kv cache
paged_kv_indices: torch.Tensor | None = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: torch.Tensor | None = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None
# The dtype of MLA out tensor
attn_out_dtype: torch.dtype = torch.bfloat16
# The max query output length: int
max_qo_len: int | None = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
)
self.compilation_config = vllm_config.compilation_config
self.decode_attn_out_dtype = vllm_config.model_config.dtype
# kernel block size is always 1.
max_num_pages_per_req = vllm_config.model_config.max_model_len
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
# paged_kv_last_page_len is always 1s (kernel block size is always 1),
# so we create it once and reuse slices in both eager and cudagraph modes.
self.paged_kv_last_page_len = torch.ones(
max_num_reqs, dtype=torch.int32, device=device
)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
self.paged_kv_indices = torch.zeros(
max_num_pages, dtype=torch.int32, device=device
)
self.qo_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata:
# kernel block size is always 1, although the kv block size is not 1.
device = self.device
num_reqs = seq_lens_device.size(0)
mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask]
# kernel block size is always 1, so each page has exactly 1 token.
# last_page_len is always 1 - just slice the pre-initialized buffer.
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
paged_kv_indptr = torch.cat(
[
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
]
)
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_qo_len = qo_len.max().item()
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True
)
self.paged_kv_indices[num_actual_pages:].fill_(-1)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
self.paged_kv_indptr[: 1 + num_reqs].copy_(
paged_kv_indptr, non_blocking=True
)
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
# paged_kv_last_page_len already uses the pre-initialized buffer slice
# (set above), so no copy needed - buffer is always 1s.
self.qo_indptr[: 1 + num_reqs].copy_(
query_start_loc_device, non_blocking=True
)
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
qo_indptr = self.qo_indptr[: 1 + num_reqs]
else:
qo_indptr = torch.arange(
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
)
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype,
)
return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
assert num_heads == 16 or num_heads == 128, (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value."
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
output = self.flash_attn_varlen_func( # type: ignore[call-arg]
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
return output
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
assert attn_metadata.decode.max_qo_len is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(
B,
self.num_heads,
self.kv_lora_rank,
dtype=attn_metadata.decode.attn_out_dtype,
device=q.device,
)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,
self.scale,
attn_metadata.decode.qo_indptr,
attn_metadata.decode.max_qo_len,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
q_scale=layer._q_scale,
kv_scale=layer._k_scale,
)
return o, None

View File

@@ -0,0 +1,368 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar
import numpy as np
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
@triton.jit
def fetch_id_to_ragged_kernel(
in_tensor_ptr, # [num_seq, topk]
cumsum_ptr, # [num_seq + 1]
out_tensor_ptr, # [max_num_seq * topk]
in_tensor_ptr_stride,
TOPK: tl.constexpr,
TOKEN_NUM: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
seq_id = tl.program_id(0)
block_id = tl.program_id(1)
offset = tl.arange(0, BLOCK_SIZE)
token_start = tl.load(cumsum_ptr + seq_id)
token_end = tl.load(cumsum_ptr + seq_id + 1)
token_num = token_end - token_start
row_offset = block_id * BLOCK_SIZE
if row_offset >= token_num:
return
in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset
in_tensor_mask = (row_offset + offset) < TOPK
in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask)
out_tensor_offset = token_start + row_offset + offset
out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask
tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask)
def fetch_id_to_ragged_triton(
in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
):
num_tokens = in_tensor.size(0)
block_size = 64
num_block_per_row = triton.cdiv(topk, block_size)
grid = (
num_tokens,
num_block_per_row,
)
fetch_id_to_ragged_kernel[grid](
in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size
)
class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA_SPARSE"
@staticmethod
def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
return ROCMAiterMLASparseMetadata
@staticmethod
def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
return ROCMAiterMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
return ROCMAiterMLASparseImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
req_id_per_token: torch.Tensor
qo_indptr: torch.Tensor
paged_kv_last_page_len: torch.Tensor
paged_kv_indices: torch.Tensor
paged_kv_indptr: torch.Tensor
paged_kv_indptr_rest: torch.Tensor
block_size: int = 1
topk_tokens: int = 2048
@dataclass
class ROCMAiterMLASparseMetadataBuilder(
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.topk_tokens_tensor = torch.tensor(
[self.topk_tokens], device=device, dtype=torch.int32
)
self.max_model_len_tensor = torch.tensor(
[self.model_config.max_model_len], device=device, dtype=torch.int32
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty(
(1, 1), dtype=torch.int32, device=self.device
)
self.req_id_per_token_buffer = torch.empty(
(vllm_config.scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=device,
)
self.qo_indptr = torch.arange(
0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
)
self.paged_kv_last_page_len = torch.ones(
max_num_batched_tokens, dtype=torch.int32, device=device
)
# These two needs to be calculated in runtime,
# but we still needs to prepare the buffer
self.paged_kv_indices = torch.zeros(
[max_num_batched_tokens * self.topk_tokens],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros(
[max_num_batched_tokens + 1], dtype=torch.int32, device=device
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> ROCMAiterMLASparseMetadata:
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
self.paged_kv_indices.fill_(0)
self.paged_kv_indptr.fill_(0)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
qo_indptr = self.qo_indptr[: num_tokens + 1]
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]
metadata = ROCMAiterMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
qo_indptr=qo_indptr,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indices=paged_kv_indices,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indptr_rest=paged_kv_indptr_rest,
)
return metadata
# Take from
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
def reference_mla_sparse_prefill(
q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
) -> tuple[torch.Tensor, torch.Tensor]:
import math
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
skv = kv.shape[0]
sq = q.shape[0]
topk = indices.shape[-1]
dqk = q.shape[-1]
indices = indices[:, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= skv)
indices[invalid_indices_mask] = 0
qs = q # [s_q, h_q, d_qk]
kvs = kv[:, 0, :][indices].view(sq, topk, dqk) # [s_q, topk, d_qk]
attn_score = (qs @ kvs.transpose(1, 2)).float() # [s_q, h_q, topk]
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
attn_score *= sm_scale * math.log2(math.e)
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
return (result, lse)
class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None,
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
def _forward_bf16_kv(
self,
q: torch.Tensor, # [sq, heads, d_qk]
kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk]
topk_indices: torch.Tensor, # [sq, topk]
attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor:
num_tokens = q.shape[0]
output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
seq_len = (topk_indices != -1).sum(dim=-1)
torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
fetch_id_to_ragged_triton(
topk_indices,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.topk_tokens,
)
rocm_aiter_ops.mla_decode_fwd(
q,
kv_c_and_k_pe_cache,
output,
self.scale,
attn_metadata.qo_indptr,
1,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len,
)
return output[:, : self.num_heads, :]
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: ROCMAiterMLASparseMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
# Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
q = torch.cat(q, dim=-1)
num_actual_toks = q.shape[0]
# Get topk indices
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
topk_indices_global = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
attn_out = self._forward_bf16_kv(
q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
)
return attn_out, None

View File

@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for sparse MLA backends."""
import torch
from vllm.triton_utils import tl, triton
# Kernel with prefill workspace support and valid count tracking
@triton.jit
def _convert_req_index_to_global_index_kernel(
req_id_ptr, # int32 [num_tokens]
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
valid_count_ptr, # int32 [num_tokens] - output valid count per row
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
# shapes (compile-time where possible)
max_num_blocks_per_req: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr, # tile width along columns
HAS_PREFILL: tl.constexpr,
COUNT_VALID: tl.constexpr, # whether to count valid indices
# strides (in elements)
bt_stride0,
bt_stride1,
ti_stride0,
ti_stride1,
out_stride0,
out_stride1,
):
# program_id(0) -> token_id (row)
# program_id(1) -> tile index along columns
token_id = tl.program_id(0)
tile_id = tl.program_id(1)
# Each program covers BLOCK_N consecutive columns
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
# Load request id for this token (no mask: grid is exact)
req = tl.load(req_id_ptr + token_id)
# Load token indices for this tile
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
tok = tl.load(ti_ptr) # int32
# Only token == -1 should propagate as -1
is_invalid_tok = tok < 0
is_prefill = False
if HAS_PREFILL:
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
is_prefill = prefill_req_id >= 0
# Compute block id and in-block offset
block_id = tok // BLOCK_SIZE
inblock_off = tok % BLOCK_SIZE
# Guard block_table access
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
is_invalid_tok |= ~valid_block
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
out_val = base * BLOCK_SIZE + inblock_off
# Override with prefill output if prefill is enabled
if HAS_PREFILL:
workspace_start = tl.load(
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
)
prefill_out = workspace_start + tok
out_val = tl.where(is_prefill, prefill_out, out_val)
out_val = tl.where(is_invalid_tok, -1, out_val)
# Store results
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
tl.store(out_ptr_ij, out_val)
# Count valid indices in this tile and atomically add to row total
if COUNT_VALID:
tile_valid_count = tl.sum((~is_invalid_tok).to(tl.int32))
tl.atomic_add(valid_count_ptr + token_id, tile_valid_count)
def triton_convert_req_index_to_global_index(
req_id: torch.Tensor, # int32 [num_tokens]
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
BLOCK_N: int = 128, # tile width along columns
HAS_PREFILL_WORKSPACE: bool = False,
prefill_workspace_request_ids: torch.Tensor | None = None,
prefill_workspace_starts: torch.Tensor | None = None,
return_valid_counts: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
out[token_id, indice_id] =
block_table[req_id[token_id],
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
+ token_indices[token_id, indice_id] % BLOCK_SIZE
Only when token_indices[token_id, indice_id] == -1 do we output -1.
For safety, we also output -1 if the derived block_id would be
out-of-bounds.
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
instead of global cache slots. prefill_workspace_request_ids and
prefill_workspace_starts must be provided.
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
prefill request index (maps to prefill_workspace_starts)
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
starts for each prefill request
When return_valid_counts is True, also returns the count of valid (non -1)
indices per row, computed during the same kernel pass (no extra overhead).
"""
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
)
if HAS_PREFILL_WORKSPACE:
assert prefill_workspace_request_ids is not None
assert prefill_workspace_starts is not None
assert prefill_workspace_request_ids.dtype == torch.int32
assert prefill_workspace_starts.dtype == torch.int32
num_tokens = req_id.shape[0]
max_num_blocks_per_req = block_table.shape[1]
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
# Ensure contiguous tensors on the same device
req_id_c = req_id.contiguous()
block_table_c = block_table.contiguous()
token_indices_c = token_indices.contiguous()
out = torch.empty_like(token_indices_c)
# Allocate valid count buffer if needed (must be zero-initialized for atomics)
valid_counts: torch.Tensor | None = None
if return_valid_counts:
valid_counts = torch.zeros(
num_tokens, dtype=torch.int32, device=token_indices.device
)
# Strides in elements
bt_stride0, bt_stride1 = block_table_c.stride()
ti_stride0, ti_stride1 = token_indices_c.stride()
out_stride0, out_stride1 = out.stride()
# Prepare prefill pointers
if HAS_PREFILL_WORKSPACE:
assert prefill_workspace_request_ids is not None # for mypy
assert prefill_workspace_starts is not None # for mypy
assert prefill_workspace_request_ids.is_contiguous()
assert prefill_workspace_starts.is_contiguous()
# Exact 2D grid: tokens × column tiles
grid = (num_tokens, tiles_per_row)
_convert_req_index_to_global_index_kernel[grid](
req_id_c,
block_table_c,
token_indices_c,
out,
valid_counts,
prefill_workspace_request_ids,
prefill_workspace_starts,
# shapes / constexprs
max_num_blocks_per_req,
BLOCK_SIZE,
BLOCK_N,
HAS_PREFILL_WORKSPACE,
return_valid_counts,
# strides
bt_stride0,
bt_stride1,
ti_stride0,
ti_stride1,
out_stride0,
out_stride1,
)
if return_valid_counts:
assert valid_counts is not None
return out, valid_counts
return out

View File

@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
import ixformer.inference.functions as ixf_ops
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.parallel_state import get_dcp_group
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
# "auto",
# "bfloat16",
# ]
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
@staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl"
)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
return super()._flash_attn_varlen_diff_headdims(
q,
k,
v,
return_softmax_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
def forward_mqa(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
# layer: AttentionLayer,
k_c_normed: torch.Tensor |None = None,
k_pe: torch.Tensor |None = None,
kv_c_and_k_pe_cache_scale: torch.Tensor |None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode
q_nope = self._k_up_proj(q_nope)
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
B = q_nope.shape[0]
if self.dcp_world_size > 1:
q = torch.cat([q_nope, q_pe], dim=-1)
q = get_dcp_group().all_gather(q, dim=1)
o = torch.empty(B,
q.shape[1],
self.kv_lora_rank,
dtype=q_nope.dtype,
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q_int8, q_scale = ops.quant_kv(q)
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
o,
q_int8,
q_scale,
kv_c_and_k_pe_cache,
kv_c_and_k_pe_cache_scale,
self.scale,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens,
attn_metadata.decode.max_decode_seq_len,
return_softmax_lse=True
)
else:
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
return_softmax_lse=True)
return attn_out, softmax_lse
o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q_nope.dtype,
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q = torch.cat([q_nope, q_pe], dim=-1)
q_int8, q_scale = ops.quant_kv(q)
ixf_ops.vllm_paged_attention_mla_int8(
o,
q_int8,
q_scale,
kv_c_and_k_pe_cache,
kv_c_and_k_pe_cache_scale,
self.scale,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens,
attn_metadata.decode.max_decode_seq_len,
attn_metadata.decode.use_cuda_graph
)
else:
# fused q concat & cache write
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope,
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph
)
return self._v_up_proj(o), None

View File

@@ -0,0 +1,261 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""
from collections.abc import Callable
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionBackend
logger = init_logger(__name__)
class _AttentionBackendEnumMeta(EnumMeta):
"""Metaclass for AttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str):
"""Get backend by name with helpful error messages."""
try:
return super().__getitem__(name)
except KeyError:
members = cast("dict[str, Enum]", cls.__members__).keys()
valid_backends = ", ".join(members)
raise ValueError(
f"Unknown attention backend: '{name}'. "
f"Valid options are: {valid_backends}"
) from None
class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported attention backends.
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
"""
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLASH_ATTN_DIFFKV = (
"vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = (
"vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
)
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
ROCM_AITER_MLA_SPARSE = (
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
)
TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
FLASHINFER_MLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashinfer_mla_sparse."
"FlashInferMLASparseBackend"
)
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHMLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
ROCM_AITER_UNIFIED_ATTN = (
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM = None
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
path = _ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _ATTN_OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_ATTN_OVERRIDES.pop(self, None)
class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported mamba attention backends.
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
"""
MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM = None
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _MAMBA_ATTN_OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_MAMBA_ATTN_OVERRIDES.pop(self, None)
MAMBA_TYPE_TO_BACKEND_MAP = {
"mamba1": MambaAttentionBackendEnum.MAMBA1.name,
"mamba2": MambaAttentionBackendEnum.MAMBA2.name,
"short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
"linear_attention": MambaAttentionBackendEnum.LINEAR.name,
"gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
"custom": MambaAttentionBackendEnum.CUSTOM.name,
}
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
def register_backend(
backend: AttentionBackendEnum | MambaAttentionBackendEnum,
class_path: str | None = None,
is_mamba: bool = False,
) -> Callable[[type], type]:
"""Register or override a backend implementation.
Args:
backend: The AttentionBackendEnum member to register
class_path: Optional class path. If not provided and used as
decorator, will be auto-generated from the class.
Returns:
Decorator function if class_path is None, otherwise a no-op
Examples:
# Override an existing attention backend
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Override an existing mamba attention backend
@register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
class MyCustomMambaAttn:
...
# Register a custom third-party attention backend
@register_backend(AttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
# Direct registration
register_backend(
AttentionBackendEnum.CUSTOM,
"my.module.MyCustomBackend"
)
"""
def decorator(cls: type) -> type:
if is_mamba:
_MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
else:
_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
return cls
if class_path is not None:
if is_mamba:
_MAMBA_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
else:
_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
return lambda x: x
return decorator

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backend import AttentionLayer, AttentionType
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
RocmAttentionImpl,
RocmAttentionMetadataBuilder,
)
logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "ROCM_AITER_UNIFIED_ATTN"
@staticmethod
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
return RocmAiterUnifiedAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
sinks,
)
logger.info_once(
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
)
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for RocmAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (
cu_seqlens_q.shape[0] - 1,
key.shape[1] if key is not None else self.num_kv_heads,
)
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
return output
def do_kv_cache_update(
self,
layer: AttentionLayer,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)

View File

@@ -0,0 +1,461 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.paged_attn import PagedAttention
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@dataclass
class RocmAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
# For cascade attention.
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> RocmAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
# Here we set the query start locs to 0. This is to
# cover up an invalid memory access in the prefix_prefil kernel
# that we run into during graph capture (#25985)
common_attn_metadata.query_start_loc.zero_()
common_attn_metadata.query_start_loc_cpu.zero_()
return attn_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> RocmAttentionMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
use_cascade = common_prefix_len > 0
if use_cascade:
cu_prefix_query_lens = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device=self.device
)
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
attn_metadata = RocmAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
)
return attn_metadata
class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# ROCM paged attention kernel only supports block sizes 16 and 32
# due to shared memory (LDS) constraints on AMD GPUs.
# See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
# However, The limitations in [16, 32] are reasonable for a native C++ kernel,
# but vLLM should allow support for non-standard sizes via the Triton path,
# as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
# where the Triton kernel under rocm_atten does not support inference
# for a non-standard qwen3-next model with a block_size of 544.
# We have fixed the Triton kernel so that the standard model uses the original
# bit-addressing logic, while the non-standard model
# uses our optimized kernel logic.
return [16, 32, 544]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
if not cls.supports_head_size(head_size):
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "ROCM_ATTN"
@staticmethod
def get_impl_cls() -> type["RocmAttentionImpl"]:
return RocmAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
class RocmAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
RocmAttentionBackend.validate_head_size(head_size)
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}."
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for RocmAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens] if key is not None else None,
value=value[:num_actual_tokens] if value is not None else None,
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
return output
def do_kv_cache_update(
self,
layer: AttentionLayer,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
# Reshape the input keys and values and store them in the cache.
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# Determine if it is a power of 2
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if is_pow2:
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
# force using our modified Triton logic
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache,
layer.num_kv_heads, # type: ignore[attr-defined]
layer.head_size, # type: ignore[attr-defined]
)
flash_layout = False
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
class ShortConvAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "SHORT_CONV_ATTN"
@staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder
@dataclass
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
pass
class ShortConvAttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
):
metadata_cls = ShortConvAttentionMetadata

View File

@@ -0,0 +1,430 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""
import ast
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "TREE_ATTN"
@staticmethod
def get_impl_cls() -> type["TreeAttentionImpl"]:
return TreeAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
return TreeAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class TreeAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
tree_attn_bias: torch.Tensor | None = None
# Cached Prefill/decode metadata.
_cached_prefill_metadata: "TreeAttentionMetadata | None" = None
_cached_decode_metadata: "TreeAttentionMetadata | None" = None
@property
def prefill_metadata(self) -> "TreeAttentionMetadata | None":
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
q_start_loc = self.query_start_loc[self.num_decodes :]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[self.num_decodes :]
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_prefill_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc - q_start_loc[0],
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[self.num_decodes :],
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> "TreeAttentionMetadata | None":
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
q_start_loc = self.query_start_loc[: self.num_decodes + 1]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[: self.num_decodes]
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_decode_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc,
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[: self.num_decodes],
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
tree_attn_bias=self.tree_attn_bias,
)
return self._cached_decode_metadata
class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config
spec_token_tree: str | None = None
if spec := spec_config:
spec_token_tree = spec.speculative_token_tree
tree_choices: list[tuple[int, ...]] = (
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
)
# Construct the tree attention bias.
depth_counts = _get_depth_counts(tree_choices)
self.tree_attn_bias = _prepare_tree_attn_bias(
tree_choices,
depth_counts,
dtype=torch.float32,
device=device,
)
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> TreeAttentionMetadata:
decode_threshold = self.tree_attn_bias.shape[0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=decode_threshold
)
)
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
max_seq_len = common_attn_metadata.max_seq_len
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
return TreeAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
max_query_len=max_query_len,
query_start_loc=q_start_loc,
max_seq_len=max_seq_len,
seq_lens=kv_seqlens,
block_table=block_table,
slot_mapping=slot_mapping,
tree_attn_bias=self.tree_attn_bias,
)
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> TreeAttentionMetadata:
# Cache the original tree attention bias.
orig_tree_attn_bias = self.tree_attn_bias
if draft_index == 0:
# Use prefill for drafting at the root level.
self.tree_attn_bias = torch.empty(0)
else:
# Slice the tree attention bias for drafting. Exclude
# the root level.
start, end = 1, 1 + common_attn_metadata.max_query_len
self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()
# Build attention bias.
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
# Reset the tree attention bias to the original value.
self.tree_attn_bias = orig_tree_attn_bias
return attn_metadata
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
# Count the number of choices at each depth of the tree.
depth_counts = []
prev_depth = 0
for path in sorted_tree_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
return depth_counts
def _prepare_tree_attn_bias(
sorted_tree_choices: list[tuple[int, ...]],
depth_counts: list[int],
dtype: torch.dtype | None,
device: torch.device | None,
) -> torch.Tensor:
# +1 comes from the additional root node.
tree_len = len(sorted_tree_choices) + 1
tree_attn_mask = torch.full(
(tree_len, tree_len), -torch.inf, device=device, dtype=dtype
)
# Set diagonal to all zeros. Each token should
# attend to itself.
mask_val = 0
for i in range(tree_len):
tree_attn_mask[i, i] = mask_val
# Set root to all zeros. All tokens attend to it.
tree_attn_mask[:, 0] = mask_val
# Set all ancestors to zeros.
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_tree_choice = sorted_tree_choices[start + j]
# Retrieve ancestor position.
if len(cur_tree_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
)
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
start += depth_counts[i]
return tree_attn_mask
class TreeAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if logits_soft_cap is None:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TreeAttentionImpl."
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TreeAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with TreeAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for TreeAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
if prefill_meta := attn_metadata.prefill_metadata:
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=prefill_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_query_len,
seqused_k=decode_meta.seq_lens,
max_seqlen_k=decode_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
qq_bias=decode_meta.tree_attn_bias,
window_size=self.sliding_window,
block_table=decode_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output

View File

@@ -0,0 +1,638 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""High-Performance Triton-only Attention layer."""
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
# constants
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
@dataclass
class TritonAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
seq_threshold_3D: int
num_par_softmax_segments: int
softmax_segm_output: torch.Tensor
softmax_segm_max: torch.Tensor
softmax_segm_expsum: torch.Tensor
# For cascade attention.
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists):
return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(
range_tensors, layout=torch.jagged
).to_padded_tensor(0)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = (
self.vllm_config.compilation_config.cudagraph_mode
in (
CUDAGraphMode.FULL_AND_PIECEWISE,
CUDAGraphMode.FULL_DECODE_ONLY,
CUDAGraphMode.FULL,
)
)
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
# A lower bound for num_q_blocks is the number of sequences.
# To ensure the minimum launch grid size is achieved, the number of sequences
# must be at least equal to the threshold below.
# If this threshold is not reached (i.e., the batch size is not large enough),
# the 3D kernel will be selected instead.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
# Modify the threshold if needed.
if self.decode_cudagraph_enabled:
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
self.seq_threshold_3D = min(
capture_sizes,
key=lambda x: abs(x - self.seq_threshold_3D),
)
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = next_power_of_2(self.headdim)
self.softmax_segm_output = torch.empty(
(
self.seq_threshold_3D,
self.num_heads_q,
self.num_par_softmax_segments,
headdim_padded,
),
dtype=torch.float32,
device=device,
)
self.softmax_segm_max = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
self.softmax_segm_expsum = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> TritonAttentionMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
use_cascade = common_prefix_len > 0
if use_cascade:
cu_prefix_query_lens = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device=self.device
)
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
attn_metadata = TritonAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
seq_threshold_3D=self.seq_threshold_3D,
num_par_softmax_segments=self.num_par_softmax_segments,
softmax_segm_output=self.softmax_segm_output,
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
return attn_metadata
class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"
@staticmethod
def get_impl_cls() -> type["TritonAttentionImpl"]:
return TritonAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
if include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4, 5)
# (num_blocks, 2, block_size, num_kv_heads, head_size)
return (0, 1, 2, 3, 4)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""TritonAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@classmethod
def supports_alibi_sqrt(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
use_alibi_sqrt: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}."
)
self.use_alibi_sqrt = use_alibi_sqrt
self.supports_quant_query_input = current_platform.is_cuda()
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Paged Attention impl. in Triton.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[num_blocks, 2, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for TritonAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
if self.kv_cache_dtype.startswith("fp8"):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
seq_threshold_3D = attn_metadata.seq_threshold_3D
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
softmax_segm_output = attn_metadata.softmax_segm_output
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=self.use_alibi_sqrt,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor,
)
return output
def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.
Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)
# Use encoder-specific metadata for sequence information
query_start_loc = attn_metadata.query_start_loc
seq_lens = attn_metadata.seq_lens
max_query_len = attn_metadata.max_query_len
# Call flash attention directly on Q, K, V tensors
context_attention_fwd(
q=query,
k=key,
v=value,
o=output,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_input_len=max_query_len,
is_causal=False, # Encoder attention is bidirectional
softmax_scale=self.scale,
sliding_window_q=self.sliding_window[0],
sliding_window_k=self.sliding_window[1],
)
return output
def do_kv_cache_update(
self,
layer: AttentionLayer,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)

View File

@@ -0,0 +1,866 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass
from typing import (
TYPE_CHECKING,
Any,
Literal,
Protocol,
get_args,
)
import numpy as np
import torch
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
CommonAttentionMetadata,
subclass_attention_backend,
)
logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
PAD_SLOT_ID = -1
def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
@functools.lru_cache
def get_kv_cache_layout():
# Format specified by the code.
global _KV_CACHE_LAYOUT_OVERRIDE
cache_layout: Literal["NHD", "HND"] | None = None
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
logger.info_once(
"`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
"Setting KV cache layout to %s.",
cache_layout,
)
return cache_layout
# Format specified by the user.
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
# When neither the user nor the override specified a layout, get default
if cache_layout is None:
cache_layout = get_kv_connector_cache_layout()
else:
assert is_valid_kv_cache_layout(cache_layout)
logger.info_once(
"`VLLM_KV_CACHE_LAYOUT` environment variable "
"detected. Setting KV cache layout to %s.",
cache_layout,
)
return cache_layout
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
global _KV_CACHE_LAYOUT_OVERRIDE
_KV_CACHE_LAYOUT_OVERRIDE = cache_layout
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters. Should not be used for
trtllm-gen backend since it supports different values for the following
hyperparameters.
"""
window_left: int
logits_soft_cap: float | None
sm_scale: float
has_sinks: bool = False
# has same params for all layers
has_same_window_lefts: bool | None = field(default=None, compare=False)
has_same_all_params: bool | None = field(default=None, compare=False)
def get_per_layer_parameters(
vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
"""
Scan layers in `layer_names` and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(
vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, cls_)
# Infer hyperparameters from the attention layer
window_size = getattr(impl, "sliding_window", None)
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
sm_scale = impl.scale
has_sinks = getattr(impl, "sinks", None) is not None
per_layer_params[key] = PerLayerParameters(
window_left, logits_soft_cap, sm_scale, has_sinks
)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
"""
Currently, FlashInfer backend other than trtllm-gen
only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
global_params.has_same_window_lefts = all(
params.window_left == global_params.window_left for params in param_sets
)
global_params.has_same_all_params = all(
params == global_params for params in param_sets
)
return global_params
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if are performing a chunked prefill a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
attn_chunk_size: int,
common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0,
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
block_table = common_attn_metadata.block_table_tensor
device = common_attn_metadata.query_start_loc.device
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
).astype(np.int32)
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: max a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
# set the first block since this may be a partial block
seqlens_q_local[arange == 0] = q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
)[arange > 0]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
cu_seqlens_q_local[0] = 0
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
num_computed_tokens_local = seqlens_k_local - seqlens_q_local
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
)
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // block_size
assert attn_chunk_size % block_size == 0, (
f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
)
pages_per_local_batch = attn_chunk_size // block_size
# Create a block_table for the local attention blocks
# For out example if we have a block-table like (assuming block_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices = block_starts[:, None] + np.arange(
pages_per_local_batch, dtype=np.int32
)
block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(
np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch,
)
# NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
# regression when using numpy arrays (batch and block indices) to index into
# torch tensor (block_table). As a workaround, convert numpy arrays to torch
# tensor first, which recovers perf.
batch_indices_torch = torch.from_numpy(batch_indices)
block_indices_torch = torch.from_numpy(block_indices)
# Save as a lambda so we can return this for update_block_table
make_block_table = lambda block_table: block_table[
batch_indices_torch, block_indices_torch
].view(virtual_batches, -1)
block_table_local = make_block_table(block_table)
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
max_seq_len = int(seq_lens_cpu.max())
return CommonAttentionMetadata(
query_start_loc_cpu=query_start_loc_cpu,
query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
num_reqs=len(seq_lens_cpu),
num_actual_tokens=common_attn_metadata.num_actual_tokens,
max_query_len=seqlens_q_local.max(),
max_seq_len=max_seq_len,
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
), make_block_table
def make_kv_sharing_fast_prefill_common_attn_metadata(
common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
if common_attn_metadata.max_query_len == 1:
# All requests are decode (assume 1 token for now)
# Skip computing fast prefill path
return common_attn_metadata
assert common_attn_metadata.logits_indices_padded is not None
assert common_attn_metadata.num_logits_indices is not None
logits_indices_padded = common_attn_metadata.logits_indices_padded
num_logits_indices = common_attn_metadata.num_logits_indices
# Get rid of CUDAGraph padding, if any
logits_indices = logits_indices_padded[:num_logits_indices]
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
# Example inputs
# num_reqs: 3
# generation_indices: [14, 18, 19, 27]
# query_start_loc: [0, 15, 20, 28]
# seq_lens: [41, 31, 40]
# Find how many decode indices belong to each request
# request_ids: [0, 1, 1, 2]
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
# Figure out how many tokens are in each request
# num_decode_tokens: [1, 2, 1]
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
# Calculate new query_start_loc with tokens in generation_indices
# decode_query_start_loc: [0, 1, 3, 4]
decode_query_start_loc = torch.empty(
num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype
)
decode_query_start_loc[0] = 0
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
decode_max_query_len = int(num_decode_tokens.max().item())
total_num_decode_tokens = int(num_decode_tokens.sum().item())
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=decode_query_start_loc,
query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
seq_lens=common_attn_metadata.seq_lens,
num_reqs=num_reqs,
num_actual_tokens=total_num_decode_tokens,
max_query_len=decode_max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
return common_attn_metadata
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
Returns:
num_decodes: The number of decode requests.
num_extends: The number of extend requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_extend_tokens: The number of tokens in the extend requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len = common_attn_metadata.max_query_len
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
if max_query_len <= decode_threshold:
return num_reqs, 0, 0, num_tokens, 0, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill_or_extend = query_lens > decode_threshold
is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
first_prefill = is_prefill.int().argmax(dim=-1).item()
num_decodes = first_extend
num_decode_tokens = query_start_loc[first_extend].item()
if not torch.any(is_prefill_or_extend):
return (num_decodes, 0, 0, num_decode_tokens, 0, 0)
num_prefills_or_extends = num_reqs - num_decodes
num_prefill_or_extend_tokens = num_tokens - num_decode_tokens
if not torch.any(is_prefill):
return (
num_decodes,
num_prefills_or_extends,
0,
num_decode_tokens,
num_prefill_or_extend_tokens,
0,
)
num_extends = first_prefill - num_decodes
num_prefills = num_reqs - first_prefill
num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
return (
num_decodes,
num_extends,
num_prefills,
num_decode_tokens,
num_extend_tokens,
num_prefill_tokens,
)
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.
Returns:
num_decodes: The number of decode requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len = common_attn_metadata.max_query_len
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold and (
not require_uniform or decode_threshold <= 1
):
return num_reqs, 0, num_tokens, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1]
if query_lens[0].item() > decode_threshold:
# first request is not decode, so no decode requests
return 0, num_reqs, 0, num_tokens
if require_uniform:
# check if we are in a padded uniform batch; this is used for full-CGs, some
# requests may have a query length of 0 but since they are padding its fine
# to treat them as decodes (ensures num_decodes matches the captured size)
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
return num_reqs, 0, num_tokens, 0 # all decodes
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item()
num_prefill_tokens = num_tokens - num_decode_tokens
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
"""
Split the prefill requests into chunks such that the total sequence length
of each chunk is less than or equal to the workspace size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
workspace_size: The maximum workspace size (in tokens) per chunk.
request_offset: The offset to add to the request indices.
Returns:
A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.
"""
chunk_bounds = []
i, n = 0, len(seq_lens_cpu)
assert torch.all(seq_lens_cpu <= workspace_size).item()
while i < n:
start, chunk_total = i, 0
while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
chunk_total += s
i += 1
chunk_bounds.append((start + request_offset, i + request_offset))
return chunk_bounds
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
) -> bool:
"""
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
Returns:
True if the batch was modified, False otherwise.
"""
# We now want to reorder the batch into decode → extend → prefill order
# where:
# decode: request with num_scheduled_tokens <= decode_threshold
# extend: non-decode request with existing context
# prefill: non-decode request with no existing context
# NOTE for now we loosely use "decode" to mean requests where attention is
# likely memory-bound and "prefill" to mean requests where attention is
# likely compute-bound,
num_reqs = len(input_batch.req_ids)
num_scheduled_tokens = [
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
]
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
is_prefill = num_computed_tokens_np == 0
is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill)
is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill)
# Desired order: decode → extend → prefill
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
req_regions[is_extend] = 1
req_regions[is_prefill] = 2
num_decodes = int(is_decode.sum())
num_extends = int(is_extend.sum())
target_regions = np.zeros(num_reqs, dtype=np.int32)
target_regions[num_decodes : num_decodes + num_extends] = 1
target_regions[num_decodes + num_extends :] = 2
needs_swap = req_regions != target_regions
if not needs_swap.any():
return False
# Extract indices that need swapping and sort by target region
orig_indices = np.where(needs_swap)[0]
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
src_indices = orig_indices[sorted_order]
src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
for src in src_dest_map:
dst = src_dest_map[src]
while src != dst:
input_batch.swap_states(src, dst)
# Mark dst as done by updating its destination to itself
next_dst = src_dest_map.get(dst, dst)
src_dest_map[dst] = dst
dst = next_dst
return True
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
"""
Reshapes the query tensor for the specified batch size, so that
it has shape (batch_size, seq_len, num_heads, head_dim).
"""
assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
total_tokens = query.shape[0]
num_heads = query.shape[1]
head_dim = query.shape[2]
assert total_tokens % batch_size == 0, (
f"{total_tokens=} is not divisible by {batch_size=}"
)
seq_len = total_tokens // batch_size
return query.view(batch_size, seq_len, num_heads, head_dim)
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
"""
Reshapes the attention output tensor, so that
the batch_size and seq_len dimensions are combined.
"""
if attn_output.dim() == 3:
# Already in the correct shape
return attn_output
assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
total_tokens = attn_output.shape[0] * attn_output.shape[1]
return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
def subclass_attention_metadata(
name_prefix: str,
metadata_cls: Any,
fields: list[tuple[str, Any, Any]],
) -> Any:
"""
Return a new subclass of `metadata_cls` with additional fields
"""
name: str = name_prefix + metadata_cls.__name__ # type: ignore
Wrapped = make_dataclass(name, fields, bases=(metadata_cls,))
return Wrapped
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
logits_indices_padded: torch.Tensor | None = None
num_logits_indices: int | None = None
def create_fast_prefill_custom_backend(
prefix: str,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
underlying_builder = underlying_attn_backend.get_builder_cls()
class FastPrefillAttentionBuilder(underlying_builder): # type: ignore
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_common_attn_metadata = (
make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
)
metadata = super().build(
common_prefix_len, new_common_attn_metadata, fast_build
)
class KVSharingFastPrefillAttentionMetadata(
metadata.__class__, # type: ignore
KVSharingFastPrefillMetadata,
):
def __init__(self, metadata, common_attn_metadata):
# Shallow copy all fields in metadata cls
for _field in fields(metadata.__class__):
setattr(self, _field.name, getattr(metadata, _field.name))
self.logits_indices_padded = (
common_attn_metadata.logits_indices_padded
)
self.num_logits_indices = common_attn_metadata.num_logits_indices
return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=FastPrefillAttentionBuilder,
)
return attn_backend
def compute_causal_conv1d_metadata(
query_start_loc_p_cpu: torch.Tensor,
*,
device: torch.device,
):
# Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync.
assert query_start_loc_p_cpu.device.type == "cpu"
seqlens = query_start_loc_p_cpu.diff()
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]["nums"] = nums
nums_dict[BLOCK_M]["tot"] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]["mlist"] = mlist
mlist_len = len(nums_dict[BLOCK_M]["mlist"])
nums_dict[BLOCK_M]["mlist_len"] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]["offsetlist"] = offsetlist
if batch_ptr is None:
# Update default value after class definition
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
)
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
)
else:
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS
).fill_(PAD_SLOT_ID)
batch_ptr[0:mlist_len].copy_(mlist, non_blocking=True)
token_chunk_offset_ptr[ # type: ignore
0:mlist_len
].copy_(offsetlist, non_blocking=True)
nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore
return nums_dict, batch_ptr, token_chunk_offset_ptr
def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
dcp_size: int = 1,
dcp_rank: int | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each dcp rank.
Only consider dcp now, we can extend the case of cp based on this.
"""
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
.unsqueeze(0)
.repeat(num_requests, 1)
)
else:
rank_offsets = torch.tensor(
[[dcp_rank]], dtype=torch.int32, device=seq_lens.device
)
seq_lens_tiled = (
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
)
base = (
seq_lens_tiled
// cp_kv_cache_interleave_size
// dcp_size
* cp_kv_cache_interleave_size
)
remainder = seq_lens_tiled - base * dcp_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)
def mamba_get_block_table_tensor(
block_table: torch.Tensor,
seq_lens: torch.Tensor,
kv_cache_spec: KVCacheSpec,
mamba_cache_mode: str,
) -> torch.Tensor:
"""
Get the block table tensor for mamba kernels from the input
common_attn_metadata.block_table_tensor given different mamba cache modes.
- "all": input (#requests, cdiv(max_model_len, block_size));
output (#requests, cdiv(max_model_len, block_size)).
- "none": input (#requests, 1 + num_speculative_blocks);
output (#requests, 1 + num_speculative_blocks).
- "align": input (#requests, cdiv(max_model_len, block_size));
output (#requests, 1 + num_speculative_blocks), which are the last
1 + num_speculative_blocks of each request.
"""
if mamba_cache_mode in ("all", "none"):
return block_table
else:
assert isinstance(kv_cache_spec, MambaSpec)
# NOTE: For 0-length requests in CUDA graph, use a start_index of 0
# to handle the invalid block table.
start_indices = torch.clamp(
(seq_lens - 1) // kv_cache_spec.block_size,
min=0,
)
# Use int32 for arithmetic to avoid dtype promotion overhead,
# then convert to int64 for gather (which requires Long indices)
offsets = torch.arange(
1 + kv_cache_spec.num_speculative_blocks,
device=block_table.device,
dtype=torch.int32,
)
indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
return torch.gather(block_table, 1, indices_to_gather)