Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

View File

@@ -0,0 +1,497 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
logger = init_logger(__name__)
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM)
class CPUAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder,
encoder-only and encoder-decoder attention."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl
@staticmethod
def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
return CPUAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return 2, num_blocks, num_kv_heads, block_size, head_size
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class CPUAttentionMetadata:
isa: str
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
scheduler_metadata: torch.Tensor | None
causal: bool = True
# can be removed after deprecate sdpa
use_sdpa_prefill: bool = False
num_decode_tokens: int = 0
sdpa_attn_masks: list[torch.Tensor | None] | None = None
sdpa_start_loc: torch.Tensor | None = None
class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.use_sdpa_prefill = False
reorder_batch_threshold = None
if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
# in this case, decode seqs are reordered to the front of prefill seqs
# to split decode and prefill. Then use SDPA for prefill and
# cpu_attention_with_kv_cache for decode
reorder_batch_threshold = 1
self.use_sdpa_prefill = True
self._init_reorder_batch_threshold(reorder_batch_threshold, False)
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
parallel_config = vllm_config.parallel_config
self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
self.num_heads = vllm_config.model_config.get_num_attention_heads(
parallel_config
)
self.head_dim = kv_cache_spec.head_size
self.dtype = vllm_config.model_config.dtype
self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
if self.window_size is None:
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size)
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> CPUAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = False if self.is_cross_attention else common_attn_metadata.causal
sdpa_start_loc = query_start_loc
num_decode_tokens = 0
if self.use_sdpa_prefill and causal:
# Decoder, need reorder and truncate
assert self.reorder_batch_threshold
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
num_reqs = num_decodes
sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
seq_lens = seq_lens[:num_decodes]
query_start_loc = query_start_loc[: num_decodes + 1]
block_table_tensor = block_table_tensor[:num_decodes]
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
seq_lens=seq_lens,
dtype=self.dtype,
query_start_loc=query_start_loc,
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
)
attn_metadata = CPUAttentionMetadata(
isa=self.isa,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
scheduler_metadata=sheduler_metadata,
causal=causal,
use_sdpa_prefill=self.use_sdpa_prefill,
num_decode_tokens=num_decode_tokens,
sdpa_start_loc=sdpa_start_loc,
)
return attn_metadata
class CPUAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
if logits_soft_cap is not None and attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
logger.warning_once(
"CPU_ATTN does not support logits softcap for"
" ENCODER and ENCODER_ONLY, outputs may be slightly off"
)
if logits_soft_cap is None:
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type == AttentionType.ENCODER_ONLY:
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
self.attn_type = attn_type
self.sinks = sinks
if self.sinks is not None:
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer"
)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: CPUAttentionMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass for CPU attention backend.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, num_kv_heads, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for CPUAttentionBackendImpl"
)
# For warming-up
if attn_metadata is None:
return output
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
return self._run_sdpa_forward(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
self.attn_type,
)
# For decoder and cross-attention, use KV cache, size are
# [num_blocks, num_kv_heads, block_size, head_size]
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
ops.cpu_attn_reshape_and_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.isa,
)
if attn_metadata.use_sdpa_prefill:
assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
num_decode_tokens = attn_metadata.num_decode_tokens
self._run_sdpa_forward(
query[num_decode_tokens:num_actual_tokens],
key[num_decode_tokens:num_actual_tokens],
value[num_decode_tokens:num_actual_tokens],
output[num_decode_tokens:num_actual_tokens],
attn_metadata,
self.attn_type,
)
num_actual_tokens = num_decode_tokens
if num_actual_tokens > 0:
ops.cpu_attention_with_kv_cache(
query=query[:num_actual_tokens],
key_cache=key_cache,
value_cache=value_cache,
output=output[:num_actual_tokens], # type: ignore
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, # type: ignore
sliding_window=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=attn_metadata.scheduler_metadata,
s_aux=self.sinks,
)
return output
def _run_sdpa_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: CPUAttentionMetadata,
attn_type: str,
) -> torch.Tensor:
attn_masks = attn_metadata.sdpa_attn_masks
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.sdpa_start_loc,
)
elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.sdpa_start_loc,
self.sliding_window[0],
self.sliding_window[1],
query.dtype,
)
else:
attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1) # type: ignore
attn_metadata.sdpa_attn_masks = attn_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
causal_attn = attn_type == AttentionType.DECODER
sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
for i in range(len(attn_masks)):
mask = attn_masks[i]
start_q = sdpa_start_loc[i]
end_q = sdpa_start_loc[i + 1]
sub_out = (
torch.nn.functional.scaled_dot_product_attention(
query[None, :, start_q:end_q, :],
key[None, :, start_q:end_q, :],
value[None, :, start_q:end_q, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and mask is None,
scale=self.scale,
)
.squeeze(0)
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = sub_out
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
sdpa_start_loc: torch.Tensor,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
seq_num = sdpa_start_loc.size(0) - 1
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
bias = torch.arange(seq_len, dtype=dtype) # type: ignore
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = (
torch.empty((1, seq_len, seq_len), dtype=bias.dtype) # type: ignore
.fill_(-torch.inf)
.triu_(diagonal=1)
)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
sdpa_start_loc: torch.Tensor,
left_window_size: int,
right_window_size: int,
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
seq_num = sdpa_start_loc.size(0) - 1
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
mask = torch.full( # type: ignore
(1, seq_len, seq_len), # type: ignore
fill_value=1,
dtype=dtype,
)
if right_window_size != -1:
mask = torch.tril(mask, diagonal=right_window_size)
if left_window_size != -1:
mask = torch.triu(mask, diagonal=-left_window_size)
mask = torch.log(mask)
attn_biases.append(mask)
return attn_biases
def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
supports_amx = torch._C._cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
elif block_size % 32 == 0:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
else:
return "vec"
else:
return "vec16"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: torch.Tensor | None = None
spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,]
non_spec_query_start_loc: torch.Tensor | None = (
None # shape: [batch - num_spec_decodes + 1,]
)
spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: torch.Tensor | None = (
None # shape: [batch - num_spec_decodes,]
)
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
spec_token_indx: torch.Tensor | None = None
non_spec_token_indx: torch.Tensor | None = None
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.non_spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: torch.Tensor | None = None,
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
.sum()
.item()
== 0
):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
else:
spec_sequence_masks = spec_sequence_masks.to(
query_start_loc.device, non_blocking=True
)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
)
num_spec_decode_tokens = 0
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)
if num_prefills == 0 and num_decodes == 0:
spec_token_size = min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc[-1].item(),
)
spec_token_indx = torch.arange(
spec_token_size,
dtype=torch.int32,
device=query_start_loc.device,
)
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
~spec_sequence_masks, 0
]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:],
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
)
else:
has_initial_state = None
# Prepare tensors for cudagraph
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
batch_size = m.num_actual_tokens
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert non_spec_token_indx is not None and spec_token_indx is not None
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
non_spec_token_indx, non_blocking=True
)
non_spec_token_indx = self.non_spec_token_indx[
: non_spec_token_indx.size(0)
]
self.spec_token_indx[: spec_token_indx.size(0)].copy_(
spec_token_indx, non_blocking=True
)
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size
]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True
)
non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index]
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=m.num_actual_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
spec_sequence_masks=spec_sequence_masks,
spec_token_indx=spec_token_indx,
non_spec_token_indx=non_spec_token_indx,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (
m.num_reqs <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
), (
f"GDN only supports decode-only full CUDAGraph capture. "
f"Make sure batch size ({m.num_reqs}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
f"and number of tokens ({m.num_actual_tokens}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
)
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)

View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class LinearAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
return LinearAttentionMetadataBuilder
@dataclass
class LinearAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> LinearAttentionMetadata:
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
return Mamba1AttentionMetadataBuilder
@dataclass
class Mamba1AttentionMetadata:
query_start_loc_p: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return Mamba1AttentionMetadata(
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)

View File

@@ -0,0 +1,348 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
def compute_varlen_chunk_metadata(
query_start_loc: torch.Tensor,
chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
and a physical `chunk_size`, returns three tensors on the same device:
- cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
logical-chunk lengths (each logical chunk never crosses a sequence or
physical-chunk boundary).
- last_chunk_indices: (B,) int32 index of the last logical chunk
for each sequence (=-1 for empty sequences).
- seq_idx_chunks: (nchunks,) int32 sequence index for each logical
chunk in order.
This is intentionally lightweight and CPU-side; it mirrors the metadata
produced by the V1 Mamba2 meta-data builder and is exported so tests
(and other callers) can avoid duplicating the logic.
"""
assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
device = query_start_loc.device
qsl64 = query_start_loc.to(torch.int64)
starts = qsl64[:-1].tolist()
ends = qsl64[1:].tolist()
total = int(qsl64[-1].item())
chunk_lens: list[int] = []
seq_idx_chunks: list[int] = []
last_chunk_indices: list[int] = [-1] * len(starts)
for b, (s, e) in enumerate(zip(starts, ends)):
if e <= s:
# empty sequence
continue
pos = s
while pos < e:
# split at both sequence boundaries and physical chunk boundaries
room = chunk_size - (pos % chunk_size)
take = min(room, e - pos)
chunk_lens.append(int(take))
seq_idx_chunks.append(b)
last_chunk_indices[b] = len(chunk_lens) - 1
pos += take
# Exclusive prefix sum over logical-chunk lengths
if chunk_lens:
cu_chunk_seqlens = torch.tensor(
[0] + list(itertools.accumulate(chunk_lens)),
device=device,
dtype=torch.int32,
)
# Final boundary must equal total tokens
assert int(cu_chunk_seqlens[-1].item()) == total
else:
cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
last_chunk_indices_t = (
torch.tensor(last_chunk_indices, device=device, dtype=torch.int32)
if len(starts) > 0
else torch.empty((0,), device=device, dtype=torch.int32)
)
seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
class Mamba2AttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder
@dataclass
class Mamba2AttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc_p: torch.Tensor
seq_lens: torch.Tensor
prep_initial_states: bool
chunk_size: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
seq_idx_p: torch.Tensor | None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offests into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None
state_indices_tensor: torch.Tensor # shape: [batch,]
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models"
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba2AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
seq_lens = common_attn_metadata.seq_lens
query_start_loc_p = None
seq_idx_p = None
cu_chunk_seqlen_p = None
last_chunk_indices_p = None
# Need flags to indicate if there are initial states
has_initial_states_p = None
prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# Additional cache-related varaiables:
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Compute seq_idx for prefill only
if num_prefills > 0:
# [batch,]
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
# The code below carefully constructs the chunks such that:
# 1. Chunks contain tokens from a *single* sequence only.
# 2. For every sequence, we are guaranteed that we can
# retrieve the mamba state *every* chunk_size tokens.
# Constraint (1) dramatically simplifies the mamba2 kernels.
# Constraint (2) dramatically simplifies the implementation
# of prefix caching for mamba2 (wip). We need to take care
# of the interaction with chunked prefill in order to
# satisfy constraint (2).
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
seq_idx_p = torch.as_tensor(
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
seq_lens=seq_lens,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,
has_initial_states_p=has_initial_states_p,
seq_idx_p=seq_idx_p,
state_indices_tensor=state_indices_tensor,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
return attn_metadata

View File

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from typing import ClassVar, TypeVar
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
M = TypeVar("M")
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
self.state_indices_tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size,
),
),
dtype=torch.int32,
device=device,
)
self.block_idx_last_scheduled_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_computed_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
else:
self.state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, (
"Mamba only supports decode-only full CUDAGraph capture. "
"Make sure all cudagraph capture sizes <= max_num_seq."
)
m.max_query_len = 1 # decode-only
return self.build(0, m)
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,
mamba_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
# Block index of the last computed token
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
# which is <= block index for the first scheduled token
block_idx_first_scheduled_token = (
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
)
# which is <= block index of the last scheduled token
block_idx_last_scheduled_token = (
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
return (
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
)

View File

@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.attention.backends.mla.common import MLACommonBackend
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
AiterMLAImpl,
AiterMLAMetadataBuilder,
)
class AiterTritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "AITER_TRITON_MLA"
@staticmethod
def get_impl_cls() -> type["AiterTritonMLAImpl"]:
return AiterTritonMLAImpl
@staticmethod
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
class AiterTritonMLAImpl(AiterMLAImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
from aiter.ops.triton.mha import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
result = self.flash_attn_varlen_func(
q,
k,
v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
# Transpose the LSE if Triton MHA is used:
# (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
if type(result) is tuple and return_softmax_lse:
output, lse = result
lse = lse.T.contiguous()
return (output, lse)
return result

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import ClassVar
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [128]
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
@staticmethod
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl
@staticmethod
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 10
class SM100Workspace:
def __init__(self, initial_workspace_size):
self._workspace_buf = torch.empty(
initial_workspace_size, device="cuda", dtype=torch.uint8
)
self._block_size = 128 # Forced to 128
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
# (assumes all devices are similar)
properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
self._sm_count = properties.multi_processor_count
def get_buf(self):
return self._workspace_buf
def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
batch_size = attn_metadata.num_reqs
max_seq_len = attn_metadata.max_query_len
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seq_len * self._block_size,
batch_size,
self._sm_count,
num_kv_splits=num_kv_splits,
)
if self._workspace_buf.shape[0] < workspace_size:
self._workspace_buf.resize_(workspace_size)
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
MAX_HEADS = 128
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
q_pad_num_heads=MAX_HEADS,
**mla_args,
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl"
)
# TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
# issues. In case the code hangs, use:
# FORCE_NUM_KV_SPLITS=1
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
if force_num_kv_splits:
logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
self._num_kv_splits = int(force_num_kv_splits)
else:
self._num_kv_splits = -1 # => Auto-detect
# Share workspace buffer across all executions
self._workspace = g_sm100_workspace
def _sm100_cutlass_mla_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor,
page_table: torch.Tensor,
workspace: torch.Tensor,
sm_scale: float,
num_kv_splits: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
assert kv_c_and_k_pe_cache.ndim == 3, (
"kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
kv_c_and_k_pe_cache.ndim
)
)
B_q, H, D_q_nope = q_nope.shape
B_q_2, H_2, D_q_pe = q_pe.shape
assert (B_q == B_q_2) and (H == H_2)
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
D_latent = 512
D_rope = 64
assert D_q_nope == D_latent
assert D_q_pe == D_rope
assert D_ckv == D_latent + D_rope
MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
)
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
assert seq_lens.dtype == torch.int32, (
f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
)
assert page_table.dtype == torch.int32, (
f"page_table.dtype needs to be int32 but got {page_table.dtype}."
)
dtype = (
torch.bfloat16
if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype
)
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
lse = (
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode
else torch.Tensor()
)
ops.sm100_cutlass_mla_decode(
out,
lse,
q_nope,
q_pe,
kv_c_and_k_pe_cache,
seq_lens,
page_table,
workspace,
sm_scale,
num_kv_splits,
)
if H < MAX_HEADS:
# Extract the subsets of the outputs
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
out = out[:, :H]
return out, lse
def _forward_decode(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# Adjust workspace size (if necessary)
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
# Run MLA
o, lse = self._sm100_cutlass_mla_decode(
q_nope,
q_pe,
kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table,
self._workspace.get_buf(),
self.scale,
self._num_kv_splits,
)
return o, (lse if self.need_to_return_lse_for_decode else None)

View File

@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.attention.utils.fa_utils import (
flash_attn_supports_mla,
get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
@staticmethod
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
return FlashAttnMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
return FlashAttnMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 9
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if not flash_attn_supports_mla():
return "FlashAttention MLA not supported on this device"
return None
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
query_start_loc: torch.Tensor
max_query_len: int
max_seq_len: int
scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
pass
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
super().__init__(
kv_cache_spec,
layer_names,
vllm_config,
device,
FlashAttnMLAMetadata,
supports_dcp_with_varlen=(interleave_size == 1),
)
self.max_num_splits = 0 # No upper bound on the number of splits.
self.fa_aot_schedule = get_flash_attn_version() == 3
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.fa_aot_schedule:
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device=self.device,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = (
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
if vllm_is_batch_invariant():
self.max_num_splits = 1
def _schedule_decode(
self,
num_reqs,
cu_query_lens,
max_query_len,
seqlens,
max_seq_len,
causal,
max_num_splits,
):
if self.fa_aot_schedule:
return get_scheduler_metadata(
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads * self.dcp_world_size,
num_heads_kv=1,
headdim=self.mla_dims.qk_rope_head_dim,
cache_seqlens=seqlens,
qkv_dtype=self.kv_cache_spec.dtype,
headdim_v=self.mla_dims.kv_lora_rank,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
num_splits=max_num_splits,
)
return None
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if vllm_is_batch_invariant():
max_num_splits = 1
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
cu_query_lens=query_start_loc_device,
max_query_len=max_query_len,
seqlens=seq_lens_device,
max_seq_len=max_seq_len,
causal=True,
max_num_splits=max_num_splits,
)
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
# Ensure the persistent buffer is large enough
assert n <= self.scheduler_metadata.shape[0], (
f"Scheduler metadata size {n} exceeds buffer size "
+ f"{self.scheduler_metadata.shape[0]}"
)
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
return metadata
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashAttnMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttnMLAImpl"
)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
)
def _forward_decode(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
# to prevent invalid grid configuration during graph capture.
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
attn_out = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
q_v=q_nope,
max_seqlen_q=max_seqlen_q,
cu_seqlens_q=attn_metadata.decode.query_start_loc,
max_seqlen_k=attn_metadata.decode.max_seq_len,
seqused_k=attn_metadata.decode.seq_lens,
block_table=attn_metadata.decode.block_table,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=self.need_to_return_lse_for_decode,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
)
if self.need_to_return_lse_for_decode:
o, lse = attn_out
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
else:
o = attn_out
return o, None

View File

@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
logger = init_logger(__name__)
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl
@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 10
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return "HND"
g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device="cuda",
)
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashInferMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferMLAImpl"
)
self._workspace_buffer = g_fi_workspace
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
def _forward_decode(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if isinstance(q, tuple):
q_nope, q_pe = q
q = torch.cat([q_nope, q_pe], dim=-1)
# trtllm API requires extra dimension q_len_per_request for MTP
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
logger.warning_once(
"""FlashInferMLAImpl got a query of uneven length.
This usually indicates an issue in batch reordering
or incorrect setup in dummy_run."""
)
q = q.unsqueeze(1)
else:
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])
# TODO: Return LSE pending support from Flashinfer API:
# https://github.com/flashinfer-ai/flashinfer/pull/1566
return o, None

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_dense_supported,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA"
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major in [9, 10]
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if use_sparse:
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
return is_flashmla_sparse_supported()[1]
else:
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
return is_flashmla_dense_supported()[1]
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
# ^ TODO(matt): tune this
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
)
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
(num_sms, 8),
device=self.device,
dtype=torch.int32,
)
self.cg_buf_num_splits = torch.empty(
(vllm_config.scheduler_config.max_num_seqs + 1),
device=self.device,
dtype=torch.int32,
)
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> FlashMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# we use the max but all should be the same due to uniform length requirement
max_query_len = query_lens_cpu.max().item()
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens_device,
num_q_tokens_per_head_k,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
sm_parts = tile_scheduler_metadata.size(0)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
:sm_parts
]
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
tile_scheduler_metadata = tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
# Num splits needs to monotonically increasing
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# it needs to monotonically increasing by 1)
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
is_supported, reason = is_flashmla_dense_supported()
assert is_supported, reason
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl"
)
def _forward_decode(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# TODO: (zyongye) decode function for mla here
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
# mypy assertion: q is now always a tensor
assert isinstance(q, torch.Tensor)
num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes)
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
num_splits = attn_metadata.decode.num_splits
if vllm_is_batch_invariant():
device = q.device
dtype = torch.int32
B = q.shape[0]
# block_table shape: [batch_size, max_num_blocks_per_seq]
# The number of blocks per sequence is in the second dimension
topk = attn_metadata.decode.block_table.shape[-1]
B_TOPK = 64
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
end_block_idx = topk // B_TOPK
# Single partition => num_sm_parts = 1
# TileSchedulerMetaDataSize = 8, layout:
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
# begin_n_split_idx, _, _, _]
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
tile_scheduler_metadata[0, 0] = 0 # begin_idx
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
tile_scheduler_metadata[0, 3] = end_block_idx
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
# fields [5..7] stay 0
# Non-split path ignores num_splits, but the API requires it:
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
o = reshape_attn_output_for_spec_decode(o)
return o, lse

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,345 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
MultipleOf,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
split_prefill_chunks,
)
logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 128]
@staticmethod
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
return DeepseekV32IndexerMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
assert num_kv_heads == 1
return (num_blocks, block_size, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
return (0, 1, 2, 3)
return (0, 1, 2)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
@dataclass
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
@dataclass
class DeepseekV32IndexerMetadata:
# FIXME (zyongye)
# hacky way to access the data now, need to be in chunked meta
seq_lens: torch.Tensor
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# The dimension of the attention heads
head_dim: int
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
decode: DeepSeekV32IndexerDecodeMetadata | None = None
prefill: DeepseekV32IndexerPrefillMetadata | None = None
# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
start_seq_loc: 1D long tensor [B+1], cumulative counts of
selected tokens per batch.
Example: [0, 2, 4, 7] ->
batch sizes (selected) [2, 2, 3], N=7 tokens total.
seq_len_per_batch: 1D long tensor [B],
full sequence length (KV length) of each batch.
Example: [5, 9, 4].
Returns:
start_tensor: 1D long tensor [N], start offset in the
concatenated KV cache for each token's batch.
end_location: 1D long tensor [N],
**exclusive** end = start + token's local position.
(So the attended KV slice is kv[start:end].)
Assumes each batch contributes its full `seq_len_per_batch[i]`
keys to the KV cache, andthe selected tokens within a batch
are the **last** `counts[i]` positions of that sequence.
"""
q = start_seq_loc.to(dtype=torch.long)
L = seq_len_per_batch.to(dtype=torch.long)
assert q.dim() == 1 and L.dim() == 1
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
# Selected tokens per batch and totals
counts = q[1:] - q[:-1] # [B]
N = int(q[-1].item()) # total selected tokens
B = L.numel()
if N == 0:
return (
torch.empty(0, dtype=torch.long, device=device),
torch.empty(0, dtype=torch.long, device=device),
)
# KV start offsets per batch in the concatenated KV cache
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
# For each selected token, which batch does it belong to?
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
# Map batch KV start to each token
start_tensor = kv_starts_per_batch[batch_id] # [N]
# End-align local positions inside each batch:
# local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
L_expand = torch.repeat_interleave(L, counts) # [N]
m_expand = torch.repeat_interleave(counts, counts) # [N]
# position within the selected block: 1..counts[b]
pos_within = (
torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
)
local_pos = L_expand - m_expand + pos_within # [N], 1-based
end_location = start_tensor + local_pos # exclusive end
return start_tensor.int().to(device), end_location.int().to(device)
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
# within the flashmla_sparse workspace.
# For DeepSeek-V3.2, the max_model_len is 163840.
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
return max_model_len * 40
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
reorder_batch_threshold: int = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
scheduler_config = self.vllm_config.scheduler_config
# NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
self.num_speculative_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config
else 0
)
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
props = torch.cuda.get_device_properties(self.device)
sm_count = props.multi_processor_count
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
)
# See: DeepGMM/csrc/apis/attention.hpp
self.scheduler_metadata_buffer = torch.empty(
(self.num_sms + 1, 2), dtype=torch.int32, device=self.device
)
def build_one_prefill_chunk(
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
):
prefill_query_start_loc = (
query_start_loc_cpu[reqs_start : reqs_end + 1]
- query_start_loc_cpu[reqs_start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
]
)
.to(torch.int32)
.to(self.device)
)
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start,
reqs_end,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor,
)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,
)
decode_metadata = None
if num_decodes > 0:
torch.diff(
common_attn_metadata.query_start_loc[: num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes],
)
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
)
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill=prefill_metadata,
decode=decode_metadata,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata

View File

@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
@staticmethod
def get_impl_cls() -> type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: torch.Tensor | None = None
# The page indices of the paged kv cache
paged_kv_indices: torch.Tensor | None = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: torch.Tensor | None = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None
# The dtype of MLA out tensor
attn_out_dtype: torch.dtype = torch.bfloat16
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
)
self.compilation_config = vllm_config.compilation_config
self.decode_attn_out_dtype = vllm_config.model_config.dtype
# kernel block size is always 1.
max_num_pages_per_req = vllm_config.model_config.max_model_len
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
self.paged_kv_indices = torch.zeros(
max_num_pages, dtype=torch.int32, device=device
)
self.paged_kv_last_page_len = torch.zeros(
max_num_reqs, dtype=torch.int32, device=device
)
self.qo_indptr = torch.arange(
0, max_num_reqs + 1, dtype=torch.int32, device=device
)
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata:
# kernel block size is always 1, although the kv block size is not 1.
device = self.device
num_reqs = seq_lens_device.size(0)
mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
paged_kv_indptr = torch.cat(
[
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
]
)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True
)
self.paged_kv_indices[num_actual_pages:].fill_(-1)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
self.paged_kv_indptr[: 1 + num_reqs].copy_(
paged_kv_indptr, non_blocking=True
)
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
self.paged_kv_last_page_len[:num_reqs].copy_(
paged_kv_last_page_len, non_blocking=True
)
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
qo_indptr = self.qo_indptr[: 1 + num_reqs]
else:
qo_indptr = torch.arange(
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
)
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
attn_out_dtype=self.decode_attn_out_dtype,
)
return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
assert num_heads == 16 or num_heads == 128, (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value."
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
return output
def _forward_decode(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(
B,
self.num_heads,
self.kv_lora_rank,
dtype=attn_metadata.decode.attn_out_dtype,
device=q.device,
)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,
self.scale,
attn_metadata.decode.qo_indptr,
max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
q_scale=layer._q_scale,
kv_scale=layer._k_scale,
)
return o, None

View File

@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBaseImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA_SPARSE"
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
return ROCMAiterMLASparseMetadata
@staticmethod
def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
return ROCMAiterMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
return ROCMAiterMLASparseImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass
class ROCMAiterMLASparseMetadata:
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
req_id_per_token: torch.Tensor
block_size: int = 1
topk_tokens: int = 2048
@dataclass
class ROCMAiterMLASparseMetadataBuilder(
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.topk_tokens_tensor = torch.tensor(
[self.topk_tokens], device=device, dtype=torch.int32
)
self.max_model_len_tensor = torch.tensor(
[self.model_config.max_model_len], device=device, dtype=torch.int32
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty(
(1, 1), dtype=torch.int32, device=self.device
)
self.req_id_per_token_buffer = torch.empty(
(vllm_config.scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=device,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> ROCMAiterMLASparseMetadata:
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
metadata = ROCMAiterMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
)
return metadata
# Take from
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
def reference_mla_sparse_prefill(
q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
) -> tuple[torch.Tensor, torch.Tensor]:
import math
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
skv = kv.shape[0]
sq = q.shape[0]
topk = indices.shape[-1]
dqk = q.shape[-1]
indices = indices[:, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= skv)
indices[invalid_indices_mask] = 0
qs = q # [s_q, h_q, d_qk]
kvs = kv[:, 0, :][indices].view(sq, topk, dqk) # [s_q, topk, d_qk]
attn_score = (qs @ kvs.transpose(1, 2)).float() # [s_q, h_q, topk]
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
attn_score *= sm_scale * math.log2(math.e)
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
return (result, lse)
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None,
indexer: Optional["Indexer"] = None,
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def _forward_bf16_kv(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1, 1, kv_c_and_k_pe_cache.shape[-1]
)
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = reference_mla_sparse_prefill(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
)[0]
return output[:, : self.num_heads, :]
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: ROCMAiterMLASparseMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for ROCMAiterMLASparse"
)
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
if self.is_fp8bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
ql_nope = rocm_aiter_ops.triton_fp8_bmm(
q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
else:
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
topk_indices = self.topk_indices_buffer[:num_actual_toks]
topk_indices_global = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata
)
self._v_up_proj(attn_out, out=output[:num_actual_toks])
return output

View File

@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
)
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
@staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl"
)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
return super()._flash_attn_varlen_diff_headdims(
q,
k,
v,
return_softmax_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
def _forward_decode(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
q_num_heads = q.shape[1]
o = torch.zeros(
B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
)
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if vllm_is_batch_invariant() else 4
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
q_num_heads,
num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(
q,
kv_c_and_k_pe_cache,
kv_c_cache,
o,
lse,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens,
attn_logits,
num_kv_splits,
self.scale,
PAGE_SIZE,
)
return o, lse

View File

@@ -0,0 +1,436 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, next_power_of_2
logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
# from to fp32 directly. That's why it has a dtype mapping different from GPU
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.float8_e4m3fn,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
"int8": torch.int8,
"uint8": torch.uint8,
}
try:
import tpu_inference # noqa: F401
except ImportError:
# Lazy import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB
@requires_jax
def kv_cache_update_op_impl(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update,
(kv, slot_mapping, kv_cache, num_kv_update_slices),
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
)
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
"int num_slices_per_block)"
"-> Tensor",
)
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(
kv,
slot_mapping,
kv_cache,
num_kv_update_slices,
page_size,
num_slices_per_block,
)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
) -> torch.Tensor:
return kv_cache
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS"
@staticmethod
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
# block_tables within the PallasMetadata constitute almost the entire SMEM
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
# we simply make sure that the size is smaller than half of SMEM capacity.
@staticmethod
def get_min_page_size(vllm_config: VllmConfig) -> int:
max_num_page_per_req = (
1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
)
min_page_size = cdiv(
vllm_config.model_config.max_model_len, max_num_page_per_req
)
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
@staticmethod
def get_max_num_seqs(model_len: int, page_size: int) -> int:
num_page_per_req = cdiv(model_len, page_size)
return 1024 * 1024 // 2 // num_page_per_req // 4
# TPU has limited SREGs (scalar registers), if page_size is too small, we
# can spill SREGs easily which leads to bad performance. The strategy we
# apply here is trying to split max-model-len to 16 pages which make the
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
# TODO: This is a temporary fix for vmem OOM.
# For long model length, we use 16 page-size to avoid too much
# VMEM spill. A more robust solution should be implemented to
# handle VREG spills.
if vllm_config.model_config.max_model_len > 8192:
return 16
page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
if page_size <= 16:
return 16
if page_size >= 256:
return 256
return page_size
@dataclass
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
num_kv_update_slices: torch.Tensor
num_slices_per_kv_cache_update_block: int
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
self.kv_cache_quantized_dtype = None
if kv_cache_dtype != "auto":
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
kv_cache_dtype.lower().strip()
)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: shape =
[num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl"
)
# For determine_available_memory case.
if kv_cache.numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
padded_head_size = (
cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
query = torch.nn.functional.pad(
query, (0, padded_head_size - self.head_size), value=0.0
)
key = torch.nn.functional.pad(
key, (0, padded_head_size - self.head_size), value=0.0
)
value = torch.nn.functional.pad(
value, (0, padded_head_size - self.head_size), value=0.0
)
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key,
value,
kv_cache,
slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices,
self.kv_cache_quantized_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
if self.kv_cache_quantized_dtype is not None and (
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
):
raise ValueError("k_scale_float and v_scale_float must be non-zero")
output = torch.ops.xla.ragged_paged_attention(
query,
kv_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
output = output[:, :, : self.head_size]
return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
kv_cache_quantized_dtype: torch.dtype | None = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
) -> None:
"""Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
if kv_cache_quantized_dtype is not None:
dtype_info = torch.finfo(kv_cache_quantized_dtype)
key = key.to(torch.float32) / k_scale
# NOTE: clamp is added here to avoid out of range of quantized dtype
key = torch.clamp(key, dtype_info.min, dtype_info.max)
key = key.to(kv_cache_quantized_dtype)
value = value.to(torch.float32) / v_scale
value = torch.clamp(value, dtype_info.min, dtype_info.max)
value = value.to(kv_cache_quantized_dtype)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache = kv_cache.flatten(0, 1)
new_kv_cache = torch.ops.xla.kv_cache_update_op(
kv,
slot_mapping,
kv_cache,
num_kv_update_slices,
page_size,
num_slices_per_kv_cache_update_block,
)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache.copy_(new_kv_cache)
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
if dtype.is_floating_point:
try:
return torch.finfo(dtype).bits
except TypeError:
pass
elif dtype.is_complex:
if dtype is torch.complex32:
return 32
elif dtype is torch.complex64:
return 64
elif dtype is torch.complex128:
return 128
else:
try:
return torch.iinfo(dtype).bits
# torch.iinfo cannot support int4, int2, bits8...
except TypeError:
pass
str_dtype = str(dtype)
# support torch.int4, torch.int5, torch.uint5...
if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
return int(str_dtype[-1])
raise TypeError(f"Getting the bit width of {dtype} is not supported")
def get_dtype_packing(dtype):
bits = dtype_bits(dtype)
if 32 % bits != 0:
raise ValueError(
f"The bit width must be divisible by 32, but got bits={bits}, "
"dtype={dtype}"
)
return 32 // bits
def get_page_size_bytes(
block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
) -> int:
"""Returns the size in bytes of one page of the KV cache."""
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
num_combined_kv_heads = num_kv_heads * 2
# NOTE: for the implicit padding in XLA
packing = get_dtype_packing(kv_cache_dtype)
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
return (
block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
RocmAttentionImpl,
RocmAttentionMetadataBuilder,
)
logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ROCM_AITER_UNIFIED_ATTN"
@staticmethod
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
return RocmAiterUnifiedAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
sinks,
)
logger.info_once(
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
)
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for RocmAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (
cu_seqlens_q.shape[0] - 1,
key.shape[1] if key is not None else self.num_kv_heads,
)
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
return output

View File

@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
)
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@dataclass
class RocmAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
# For cascade attention.
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> RocmAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
# Here we set the query start locs to 0. This is to
# cover up an invalid memory access in the prefix_prefil kernel
# that we run into during graph capture (#25985)
common_attn_metadata.query_start_loc.zero_()
common_attn_metadata.query_start_loc_cpu.zero_()
return attn_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> RocmAttentionMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
use_cascade = common_prefix_len > 0
if use_cascade:
cu_prefix_query_lens = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device=self.device
)
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
attn_metadata = RocmAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
)
return attn_metadata
class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
if not cls.supports_head_size(head_size):
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-config.backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod
def get_name() -> str:
return "ROCM_ATTN"
@staticmethod
def get_impl_cls() -> type["RocmAttentionImpl"]:
return RocmAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
class RocmAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
RocmAttentionBackend.validate_head_size(head_size)
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}."
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for RocmAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
return output

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
class ShortConvAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder
@dataclass
class ShortConvAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
# For causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class ShortConvAttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
):
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> ShortConvAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
if num_prefills > 0:
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = ShortConvAttentionMetadata(
query_start_loc=query_start_loc,
state_indices_tensor=state_indices_tensor,
has_initial_states_p=has_initial_states_p,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata

View File

@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""
import ast
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "TREE_ATTN"
@staticmethod
def get_impl_cls() -> type["TreeAttentionImpl"]:
return TreeAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
return TreeAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class TreeAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
tree_attn_bias: torch.Tensor | None = None
# Cached Prefill/decode metadata.
_cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
_cached_decode_metadata: Optional["TreeAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
q_start_loc = self.query_start_loc[self.num_decodes :]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[self.num_decodes :]
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_prefill_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc - q_start_loc[0],
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[self.num_decodes :],
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
q_start_loc = self.query_start_loc[: self.num_decodes + 1]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[: self.num_decodes]
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_decode_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc,
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[: self.num_decodes],
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
tree_attn_bias=self.tree_attn_bias,
)
return self._cached_decode_metadata
class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config
spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
tree_choices: list[tuple[int, ...]] = (
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
)
# Construct the tree attention bias.
depth_counts = _get_depth_counts(tree_choices)
self.tree_attn_bias = _prepare_tree_attn_bias(
tree_choices,
depth_counts,
dtype=torch.float32,
device=device,
)
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> TreeAttentionMetadata:
decode_threshold = self.tree_attn_bias.shape[0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=decode_threshold
)
)
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
max_seq_len = common_attn_metadata.max_seq_len
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
return TreeAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
max_query_len=max_query_len,
query_start_loc=q_start_loc,
max_seq_len=max_seq_len,
seq_lens=kv_seqlens,
block_table=block_table,
slot_mapping=slot_mapping,
tree_attn_bias=self.tree_attn_bias,
)
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> TreeAttentionMetadata:
# Cache the original tree attention bias.
orig_tree_attn_bias = self.tree_attn_bias
if draft_index == 0:
# Use prefill for drafting at the root level.
self.tree_attn_bias = torch.empty(0)
else:
# Slice the tree attention bias for drafting. Exclude
# the root level.
start, end = 1, 1 + common_attn_metadata.max_query_len
self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()
# Build attention bias.
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
# Reset the tree attention bias to the original value.
self.tree_attn_bias = orig_tree_attn_bias
return attn_metadata
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
# Count the number of choices at each depth of the tree.
depth_counts = []
prev_depth = 0
for path in sorted_tree_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
return depth_counts
def _prepare_tree_attn_bias(
sorted_tree_choices: list[tuple[int, ...]],
depth_counts: list[int],
dtype: torch.dtype | None,
device: torch.device | None,
) -> torch.Tensor:
# +1 comes from the additional root node.
tree_len = len(sorted_tree_choices) + 1
tree_attn_mask = torch.full(
(tree_len, tree_len), -torch.inf, device=device, dtype=dtype
)
# Set diagonal to all zeros. Each token should
# attend to itself.
mask_val = 0
for i in range(tree_len):
tree_attn_mask[i, i] = mask_val
# Set root to all zeros. All tokens attend to it.
tree_attn_mask[:, 0] = mask_val
# Set all ancestors to zeros.
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_tree_choice = sorted_tree_choices[start + j]
# Retrieve ancestor position.
if len(cur_tree_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
)
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
start += depth_counts[i]
return tree_attn_mask
class TreeAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if logits_soft_cap is None:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TreeAttentionImpl."
)
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TreeAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with TreeAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for TreeAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
if prefill_meta := attn_metadata.prefill_metadata:
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=prefill_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_query_len,
seqused_k=decode_meta.seq_lens,
max_seqlen_k=decode_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
qq_bias=decode_meta.tree_attn_bias,
window_size=self.sliding_window,
block_table=decode_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output

View File

@@ -0,0 +1,497 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""High-Performance Triton-only Attention layer."""
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
# constants
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
@dataclass
class TritonAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
seq_threshold_3D: int
num_par_softmax_segments: int
softmax_segm_output: torch.Tensor
softmax_segm_max: torch.Tensor
softmax_segm_expsum: torch.Tensor
# For cascade attention.
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists):
return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size
model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = (
self.vllm_config.compilation_config.cudagraph_mode
in (
CUDAGraphMode.FULL_AND_PIECEWISE,
CUDAGraphMode.FULL_DECODE_ONLY,
CUDAGraphMode.FULL,
)
)
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
# A lower bound for num_q_blocks is the number of sequences.
# To ensure the minimum launch grid size is achieved, the number of sequences
# must be at least equal to the threshold below.
# If this threshold is not reached (i.e., the batch size is not large enough),
# the 3D kernel will be selected instead.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
# Modify the threshold if needed.
if self.decode_cudagraph_enabled:
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
self.seq_threshold_3D = min(
capture_sizes,
key=lambda x: abs(x - self.seq_threshold_3D),
)
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = next_power_of_2(self.headdim)
self.softmax_segm_output = torch.empty(
(
self.seq_threshold_3D,
self.num_heads_q,
self.num_par_softmax_segments,
headdim_padded,
),
dtype=torch.float32,
device=device,
)
self.softmax_segm_max = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
self.softmax_segm_expsum = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> TritonAttentionMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
use_cascade = common_prefix_len > 0
if use_cascade:
cu_prefix_query_lens = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device=self.device
)
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
attn_metadata = TritonAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
seq_threshold_3D=self.seq_threshold_3D,
num_par_softmax_segments=self.num_par_softmax_segments,
softmax_segm_output=self.softmax_segm_output,
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
return attn_metadata
class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"
@staticmethod
def get_impl_cls() -> type["TritonAttentionImpl"]:
return TritonAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for TritonAttentionImpl"
)
self.attn_type = attn_type
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}."
)
self.supports_quant_query_input = current_platform.is_cuda()
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Paged Attention impl. in Triton.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[num_blocks, 2, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for TritonAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output.fill_(0)
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(1)
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
seq_threshold_3D = attn_metadata.seq_threshold_3D
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
softmax_segm_output = attn_metadata.softmax_segm_output
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor,
)
return output

File diff suppressed because it is too large Load Diff