update

vllm/v1/attention/backends/__init__.py (new file, 0 lines)

vllm/v1/attention/backends/cpu_attn.py (new file, 503 lines)
@@ -0,0 +1,503 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionImpl,
    AttentionLayer,
    AttentionMetadataBuilder,
    AttentionType,
    CommonAttentionMetadata,
    is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

logger = init_logger(__name__)

_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S390X)


class CPUAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16, torch.float32]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "CPU_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """CPU attention supports decoder, encoder,
        encoder-only and encoder-decoder attention."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )

    @staticmethod
    def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
        return CPUAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
        return CPUAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return 2, num_blocks, num_kv_heads, block_size, head_size

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


@dataclass
class CPUAttentionMetadata:
    isa: str
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    scheduler_metadata: torch.Tensor | None
    causal: bool = True

    # Can be removed once the SDPA prefill path is deprecated.
    use_sdpa_prefill: bool = False
    num_decode_tokens: int = 0
    sdpa_attn_masks: list[torch.Tensor | None] | None = None
    sdpa_start_loc: torch.Tensor | None = None


class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.use_sdpa_prefill = False
        reorder_batch_threshold = None
        if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
            # In this case, decode seqs are reordered to the front of prefill
            # seqs so that decodes and prefills can be split. SDPA is then used
            # for prefill and cpu_attention_with_kv_cache for decode.
            reorder_batch_threshold = 1
            self.use_sdpa_prefill = True

        self._init_reorder_batch_threshold(reorder_batch_threshold, False)

        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config

        parallel_config = vllm_config.parallel_config
        self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
        self.num_heads = vllm_config.model_config.get_num_attention_heads(
            parallel_config
        )
        self.head_dim = kv_cache_spec.head_size
        self.dtype = vllm_config.model_config.dtype
        self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
        if self.window_size is None:
            self.window_size = -1
        self.block_size = vllm_config.cache_config.block_size
        self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)
        self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CPUAttentionMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = False if self.is_cross_attention else common_attn_metadata.causal

        sdpa_start_loc = query_start_loc
        num_decode_tokens = 0
        if self.use_sdpa_prefill and causal:
            # Decoder case: the batch has been reordered, so truncate to the
            # decode requests only.
            assert self.reorder_batch_threshold
            (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
                split_decodes_and_prefills(
                    common_attn_metadata,
                    decode_threshold=self.reorder_batch_threshold,
                    require_uniform=True,
                )
            )
            num_reqs = num_decodes
            sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
            seq_lens = seq_lens[:num_decodes]
            query_start_loc = query_start_loc[: num_decodes + 1]
            block_table_tensor = block_table_tensor[:num_decodes]

        scheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
            num_reqs=num_reqs,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            seq_lens=seq_lens,
            dtype=self.dtype,
            query_start_loc=query_start_loc,
            causal=causal,
            sliding_window_size=self.window_size,
            isa=self.isa,
            enable_kv_split=True,
        )

        attn_metadata = CPUAttentionMetadata(
            isa=self.isa,
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            scheduler_metadata=scheduler_metadata,
            causal=causal,
            use_sdpa_prefill=self.use_sdpa_prefill,
            num_decode_tokens=num_decode_tokens,
            sdpa_start_loc=sdpa_start_loc,
        )

        return attn_metadata


class CPUAttentionBackendImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        if logits_soft_cap is not None and attn_type in (
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
        ):
            logger.warning_once(
                "CPU_ATTN does not support logits softcap for"
                " ENCODER and ENCODER_ONLY, outputs may be slightly off"
            )
        if logits_soft_cap is None:
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
        self.attn_type = attn_type

        self.sinks = sinks
        if self.sinks is not None:
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: CPUAttentionMetadata | None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass for the CPU attention backend.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for CPUAttentionBackendImpl"
            )

        # attn_metadata is None during warm-up runs.
        if attn_metadata is None:
            return output

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching.
            return self._run_sdpa_forward(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                self.attn_type,
            )

        # For decoder and cross-attention, use the KV cache. Each of K/V has
        # shape [num_blocks, num_kv_heads, block_size, head_size].
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            ops.cpu_attn_reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                attn_metadata.isa,
            )

        if attn_metadata.use_sdpa_prefill:
            assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
            num_decode_tokens = attn_metadata.num_decode_tokens
            self._run_sdpa_forward(
                query[num_decode_tokens:num_actual_tokens],
                key[num_decode_tokens:num_actual_tokens],
                value[num_decode_tokens:num_actual_tokens],
                output[num_decode_tokens:num_actual_tokens],
                attn_metadata,
                self.attn_type,
            )
            num_actual_tokens = num_decode_tokens

        if num_actual_tokens > 0:
            ops.cpu_attention_with_kv_cache(
                query=query[:num_actual_tokens],
                key_cache=key_cache,
                value_cache=value_cache,
                output=output[:num_actual_tokens],  # type: ignore
                query_start_loc=attn_metadata.query_start_loc,
                seq_lens=attn_metadata.seq_lens,
                scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,  # type: ignore
                sliding_window=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=attn_metadata.scheduler_metadata,
                s_aux=self.sinks,
            )

        return output

    def _run_sdpa_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: CPUAttentionMetadata,
        attn_type: str,
    ) -> torch.Tensor:
        attn_masks = attn_metadata.sdpa_attn_masks
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes,
                    query.dtype,
                    attn_metadata.sdpa_start_loc,
                )
            elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.sdpa_start_loc,
                    self.sliding_window[0],
                    self.sliding_window[1],
                    query.dtype,
                )
            else:
                attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1)  # type: ignore
            attn_metadata.sdpa_attn_masks = attn_masks

        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        causal_attn = attn_type == AttentionType.DECODER

        sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy()  # type: ignore
        for i in range(len(attn_masks)):
            mask = attn_masks[i]
            start_q = sdpa_start_loc[i]
            end_q = sdpa_start_loc[i + 1]
            sub_out = (
                torch.nn.functional.scaled_dot_product_attention(
                    query[None, :, start_q:end_q, :],
                    key[None, :, start_q:end_q, :],
                    value[None, :, start_q:end_q, :],
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=causal_attn and mask is None,
                    scale=self.scale,
                    enable_gqa=self.num_heads > self.num_kv_heads,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = sub_out
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    sdpa_start_loc: torch.Tensor,
) -> list[torch.Tensor]:
    attn_biases: list[torch.Tensor] = []
    seq_num = sdpa_start_loc.size(0) - 1
    sdpa_start_loc = sdpa_start_loc.numpy()  # type: ignore
    for i in range(seq_num):
        seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
        bias = torch.arange(seq_len, dtype=dtype)  # type: ignore
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = (
            torch.empty((1, seq_len, seq_len), dtype=bias.dtype)  # type: ignore
            .fill_(-torch.inf)
            .triu_(diagonal=1)
        )
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    sdpa_start_loc: torch.Tensor,
    left_window_size: int,
    right_window_size: int,
    dtype: torch.dtype,
) -> list[torch.Tensor]:
    attn_biases: list[torch.Tensor] = []
    seq_num = sdpa_start_loc.size(0) - 1
    sdpa_start_loc = sdpa_start_loc.numpy()  # type: ignore
    for i in range(seq_num):
        seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
        mask = torch.full(  # type: ignore
            (1, seq_len, seq_len),  # type: ignore
            fill_value=1,
            dtype=dtype,
        )

        if right_window_size != -1:
            mask = torch.tril(mask, diagonal=right_window_size)
        if left_window_size != -1:
            mask = torch.triu(mask, diagonal=-left_window_size)
        mask = torch.log(mask)
        attn_biases.append(mask)

    return attn_biases


def _get_attn_isa(
    dtype: torch.dtype, block_size: int, head_size: int | None = None
) -> str:
    if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
        return "vec16"
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
    supports_vxe = current_platform.get_cpu_architecture() == CpuArchEnum.S390X
    if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
        return "amx"
    elif block_size % 32 == 0:
        if supports_arm:
            # support ARM NEON FMLA and BFMMLA (bf16) for block size 32
            return "neon"
        elif supports_vxe:
            return "vxe"
        else:
            return "vec"
    else:
        return "vec16"
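Illustrative note, not part of this commit: the SDPA prefill path above feeds scaled_dot_product_attention an additive bias built by _make_sliding_window_bias. The sketch below (example seq_len and window values, plain PyTorch only) shows how taking log() of a 0/1 band matrix yields the 0/-inf bias that is added to the attention logits.

import torch

seq_len, left_window = 6, 2              # made-up example values
mask = torch.full((1, seq_len, seq_len), fill_value=1.0)
mask = torch.tril(mask, diagonal=0)      # causal: no right-lookahead
mask = torch.triu(mask, diagonal=-left_window)  # limit lookback to the window
bias = torch.log(mask)                   # 1 -> 0.0 (keep), 0 -> -inf (masked)
print(bias[0])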
vllm/v1/attention/backends/fa_utils.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Track whether upstream flash-attn is available on ROCm.
# Set during module initialization and never modified afterwards.
# This module-level flag avoids repeated import attempts and ensures
# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py).
_ROCM_FLASH_ATTN_AVAILABLE = False

if current_platform.is_cuda():
    from vllm._custom_ops import reshape_and_cache_flash
    # from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
    #     flash_attn_varlen_func,
    #     get_scheduler_metadata,
    # )
    # The stock vllm_flash_attn import above is replaced by ixformer's
    # implementation in this commit.
    from ixformer.contrib.vllm_flash_attn import (
        flash_attn_varlen_func,
        flash_attn_with_kvcache,
    )

elif current_platform.is_xpu():
    from vllm import _custom_ops as ops
    from vllm._xpu_ops import xpu_ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func  # type: ignore[assignment]
    get_scheduler_metadata = xpu_ops.get_scheduler_metadata  # type: ignore[assignment]
elif current_platform.is_rocm():
    try:
        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]

        # Mark that upstream flash-attn is available on ROCm
        _ROCM_FLASH_ATTN_AVAILABLE = True
    except ImportError:

        def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any:  # type: ignore[no-redef,misc]
            raise ImportError(
                "ROCm platform requires upstream flash-attn "
                "to be installed. Please install flash-attn first."
            )

    # ROCm doesn't use scheduler metadata (FA3 feature), provide stub
    def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
        return None

    # ROCm uses the C++ custom op for reshape_and_cache
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash


def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    # NOTE: hard-coded to FA3 for the ixformer flash-attn path imported above;
    # the platform-specific detection below is currently unreachable.
    return 3

    if current_platform.is_xpu():
        return 2
    if current_platform.is_rocm():
        # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
        return None
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )

        device_capability = current_platform.get_device_capability()

        assert device_capability is not None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )

        # 2. override if passed by environment or config
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if (
            vllm_config is not None
            and vllm_config.attention_config.flash_attn_version is not None
        ):
            fa_version = vllm_config.attention_config.flash_attn_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "FA version %d is not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )

        assert is_fa_version_supported(fa_version)
        return fa_version
    except (ImportError, AssertionError):
        return None


def flash_attn_supports_fp8() -> bool:
    return (
        get_flash_attn_version() == 3
        and current_platform.is_device_capability_family(90)
    )


def flash_attn_supports_sinks() -> bool:
    if current_platform.is_xpu():
        return True
    else:
        return get_flash_attn_version() == 3


def flash_attn_supports_mla():
    from vllm.platforms import current_platform

    if current_platform.is_cuda():
        try:
            from vllm.vllm_flash_attn.flash_attn_interface import (
                is_fa_version_supported,
            )

            return is_fa_version_supported(
                3
            ) and current_platform.is_device_capability_family(90)
        except (ImportError, AssertionError):
            pass
    return False


def is_flash_attn_varlen_func_available() -> bool:
    """Check if flash_attn_varlen_func is available.

    This function determines whether the flash_attn_varlen_func imported at module
    level is a working implementation or a stub.

    Platform-specific sources:
    - CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
    - XPU: xpu_ops.flash_attn_varlen_func
    - ROCm: upstream flash_attn.flash_attn_varlen_func (if available)

    Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
    which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is
    handled separately via _aiter_ops.is_aiter_found_and_supported().

    Returns:
        bool: True if a working flash_attn_varlen_func implementation is available.
    """
    if current_platform.is_cuda() or current_platform.is_xpu():
        # CUDA and XPU always have flash_attn_varlen_func available
        return True

    if current_platform.is_rocm():
        # Use the flag set during module import to check if
        # upstream flash-attn was successfully imported
        return _ROCM_FLASH_ATTN_AVAILABLE

    return False
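Illustrative note, not part of this commit: a minimal sketch of how downstream code typically consumes the helpers defined in fa_utils.py above, assuming the patched module is importable. With the hard-coded early return in this commit, get_flash_attn_version() always reports 3.

from vllm.v1.attention.backends.fa_utils import (
    flash_attn_supports_fp8,
    get_flash_attn_version,
)

fa_version = get_flash_attn_version()   # hard-coded to 3 in this commit
use_fp8_kv = flash_attn_supports_fp8()  # requires FA3 and an SM90-family GPU
print(fa_version, use_fp8_kv)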
vllm/v1/attention/backends/flash_attn.py (new file, 1293 lines)
(file diff suppressed because it is too large)
vllm/v1/attention/backends/flash_attn_diffkv.py (new file, 277 lines)
@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""

import torch

from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash_diffkv,
)

if is_flash_attn_varlen_func_available():
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import get_kv_cache_layout

from .flash_attn import (
    FlashAttentionBackend,
    FlashAttentionImpl,
    FlashAttentionMetadata,
    cascade_attention,
)

logger = init_logger(__name__)


class FlashAttentionDiffKVBackend(FlashAttentionBackend):
    # Default to 128 for this backend
    head_size_v: int = 128

    @classmethod
    def set_head_size_v(cls, head_size_v: int) -> None:
        cls.head_size_v = head_size_v

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_DIFFKV"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionDiffKVImpl

    # Do not modify the interface of get_kv_cache_shape,
    # but consider head_size_v when returning the result.
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (
            num_blocks,
            block_size,
            num_kv_heads,
            head_size + FlashAttentionDiffKVBackend.head_size_v,
        )

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, block_size,
            #  num_kv_heads, head_size + head_size_v)
            return (1, 0, 2, 3, 4)
        elif cache_layout == "NHD":
            stride_order = (0, 1, 2, 3)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, num_kv_heads, num_layers,
            #  block_size, head_size + head_size_v)
            return (1, 3, 0, 2, 4)
        elif cache_layout == "HND":
            stride_order = (0, 2, 1, 3)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order


class FlashAttentionDiffKVImpl(FlashAttentionImpl):
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size_v]
            kv_cache: shape =
                [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size_v]
        NOTE: For FP8 quantization, flash-attn expects the size of
        {q,k,v}_descale to be (num_sequences, num_kv_heads).
        We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use KV cache as before.
        # Different head_size for K and V.
        key_cache = kv_cache[..., : self.head_size]
        value_cache = kv_cache[..., self.head_size :]

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.

            # kv_cache update for different head_size K and V
            triton_reshape_and_cache_flash_diffkv(
                key,
                value,
                kv_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                )
                return output
            else:
                sliding_window_size = (
                    list(self.sliding_window)
                    if self.sliding_window is not None
                    else None
                )
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=sliding_window_size,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
                return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output
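Illustrative note, not part of this commit: the DiffKV backend packs K (head_size) and V (head_size_v) into a single cache tensor along the last dimension, and the forward pass splits them back with two slices. The sketch below uses made-up example shapes and only plain PyTorch.

import torch

num_blocks, block_size, num_kv_heads = 4, 16, 2
head_size, head_size_v = 192, 128  # example sizes, not from the diff
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size + head_size_v)

key_cache = kv_cache[..., :head_size]    # last dim 192
value_cache = kv_cache[..., head_size:]  # last dim 128
print(key_cache.shape, value_cache.shape)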
vllm/v1/attention/backends/flashinfer.py (new file, 1772 lines)
(file diff suppressed because it is too large)
vllm/v1/attention/backends/flex_attention.py (new file, 1024 lines)
(file diff suppressed because it is too large)
vllm/v1/attention/backends/gdn_attn.py (new file, 430 lines)
@@ -0,0 +1,430 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""

from dataclasses import dataclass

import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
    PAD_SLOT_ID,
    compute_causal_conv1d_metadata,
    mamba_get_block_table_tensor,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class GDNAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "GDN_ATTN"

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        return GDNAttentionMetadataBuilder


@dataclass
class GDNAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    has_initial_state: torch.Tensor | None = None

    spec_query_start_loc: torch.Tensor | None = None  # shape: [num_spec_decodes + 1,]
    non_spec_query_start_loc: torch.Tensor | None = (
        None  # shape: [batch - num_spec_decodes + 1,]
    )

    spec_state_indices_tensor: torch.Tensor | None = None  # shape: [batch, num_spec]
    non_spec_state_indices_tensor: torch.Tensor | None = (
        None  # shape: [batch - num_spec_decodes,]
    )
    spec_sequence_masks: torch.Tensor | None = None  # shape: [batch,]
    spec_token_indx: torch.Tensor | None = None
    non_spec_token_indx: torch.Tensor | None = None

    num_accepted_tokens: torch.Tensor | None = None  # shape: [batch,]

    # The following attributes are for the triton implementation of causal_conv1d
    nums_dict: dict | None = None
    batch_ptr: torch.Tensor | None = None
    token_chunk_offset_ptr: torch.Tensor | None = None


class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec

        if self.speculative_config:
            assert self.speculative_config.num_speculative_tokens is not None
            self.num_spec: int = self.speculative_config.num_speculative_tokens
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self._init_reorder_batch_threshold(1, self.use_spec_decode)

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )

        self.decode_cudagraph_max_bs = (
            self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
        )
        if self.compilation_config.max_cudagraph_capture_size is not None:
            self.decode_cudagraph_max_bs = min(
                self.decode_cudagraph_max_bs,
                self.compilation_config.max_cudagraph_capture_size,
            )

        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_indx = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_token_indx = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
            dtype=torch.int32,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: torch.Tensor | None = None,
        num_decode_draft_tokens_cpu: torch.Tensor | None = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        query_start_loc_cpu = m.query_start_loc_cpu
        context_lens_tensor = m.compute_num_computed_tokens()
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
        block_table_tensor = mamba_get_block_table_tensor(
            m.block_table_tensor,
            m.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )

        spec_sequence_masks_cpu: torch.Tensor | None = None
        if (
            not self.use_spec_decode
            or num_decode_draft_tokens_cpu is None
            or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
            .sum()
            .item()
            == 0
        ):
            spec_sequence_masks = None
            num_spec_decodes = 0
        else:
            spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
            num_spec_decodes = spec_sequence_masks_cpu.sum().item()
            if num_spec_decodes == 0:
                spec_sequence_masks = None
                spec_sequence_masks_cpu = None
            else:
                spec_sequence_masks = spec_sequence_masks_cpu.to(
                    query_start_loc.device, non_blocking=True
                )

        if spec_sequence_masks is None:
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1)
            )
            num_spec_decode_tokens = 0
            spec_token_indx = None
            non_spec_token_indx = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            non_spec_query_start_loc_cpu = query_start_loc_cpu
            num_accepted_tokens = None
        else:
            query_lens = query_start_loc[1:] - query_start_loc[:-1]
            assert spec_sequence_masks_cpu is not None
            query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

            # Use CPU tensors to avoid CPU-GPU sync
            non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu]
            num_decodes = (non_spec_query_lens_cpu == 1).sum().item()
            # Exclude zero-length padded sequences from prefill count.
            num_zero_len = (non_spec_query_lens_cpu == 0).sum().item()
            num_prefills = non_spec_query_lens_cpu.size(0) - num_decodes - num_zero_len
            num_decode_tokens = num_decodes
            num_prefill_tokens = (
                non_spec_query_lens_cpu.sum().item() - num_decode_tokens
            )
            num_spec_decode_tokens = (
                query_lens_cpu.sum().item() - num_prefill_tokens - num_decode_tokens
            )

            if num_prefills == 0 and num_decodes == 0:
                spec_token_size = min(
                    num_spec_decodes * (self.num_spec + 1),
                    query_start_loc_cpu[-1].item(),
                )
                spec_token_indx = torch.arange(
                    spec_token_size,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                non_spec_token_indx = torch.empty(
                    0, dtype=torch.int32, device=query_start_loc.device
                )
                # Filter by spec_sequence_masks to exclude padded sequences
                spec_state_indices_tensor = block_table_tensor[
                    spec_sequence_masks, : self.num_spec + 1
                ]
                non_spec_state_indices_tensor = None
                # Padded sequences are always at the back, so the first
                # num_spec_decodes + 1 entries of query_start_loc already
                # contain the correct cumulative token counts.
                spec_query_start_loc = query_start_loc[: num_spec_decodes + 1]
                non_spec_query_start_loc = None
                non_spec_query_start_loc_cpu = None
            else:
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens
                )
                index = torch.argsort(spec_token_masks, stable=True)
                num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
                non_spec_token_indx = index[:num_non_spec_tokens]
                spec_token_indx = index[num_non_spec_tokens:]

                spec_state_indices_tensor = block_table_tensor[
                    spec_sequence_masks, : self.num_spec + 1
                ]
                non_spec_state_indices_tensor = block_table_tensor[
                    ~spec_sequence_masks, 0
                ]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
                )
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[~spec_sequence_masks],
                    dim=0,
                    out=non_spec_query_start_loc[1:],
                )
                non_spec_query_start_loc_cpu = torch.zeros(
                    query_lens_cpu.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                )
                torch.cumsum(
                    query_lens_cpu[~spec_sequence_masks_cpu],
                    dim=0,
                    out=non_spec_query_start_loc_cpu[1:],
                )

            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
            assert non_spec_query_start_loc_cpu is not None
            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(
                    non_spec_query_start_loc_cpu,
                    device=query_start_loc.device,
                )
            )
        else:
            has_initial_state = None

        # A batch is expected to contain either non-spec decodes or spec
        # decodes, but never both at the same time.
        assert not (num_decodes > 0 and num_spec_decodes > 0), (
            f"num_decodes: {num_decodes}, num_spec_decodes: {num_spec_decodes}"
        )

        # Prepare tensors for cudagraph
        # Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
        batch_size = m.num_actual_tokens

        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_decodes == 0
            and num_spec_decodes <= self.decode_cudagraph_max_bs
            and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
        ):
            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True
            )
            spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks[:num_spec_decodes], non_blocking=True
            )
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert non_spec_token_indx is not None and spec_token_indx is not None
            self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
                non_spec_token_indx, non_blocking=True
            )
            non_spec_token_indx = self.non_spec_token_indx[
                : non_spec_token_indx.size(0)
            ]

            self.spec_token_indx[: spec_token_indx.size(0)].copy_(
                spec_token_indx, non_blocking=True
            )
            spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]

            self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True
            )
            spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
            spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True
            )
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_spec_decodes == 0
            and num_decodes <= self.decode_cudagraph_max_bs
        ):
            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True
            )
            non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
                :batch_size
            ]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[: num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True
            )
            non_spec_num_query_tokens = non_spec_query_start_loc[-1]  # type: ignore[index]
            non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
            non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            num_actual_tokens=m.num_actual_tokens,
            has_initial_state=has_initial_state,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            spec_sequence_masks=spec_sequence_masks,
            spec_token_indx=spec_token_indx,
            non_spec_token_indx=non_spec_token_indx,
            num_accepted_tokens=num_accepted_tokens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert (
            m.num_reqs <= self.decode_cudagraph_max_bs
            and m.num_actual_tokens <= self.decode_cudagraph_max_bs
        ), (
            f"GDN only supports decode-only full CUDAGraph capture. "
            f"Make sure batch size ({m.num_reqs}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
            f"and number of tokens ({m.num_actual_tokens}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
        )

        num_accepted_tokens = torch.diff(m.query_start_loc)
        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()

        return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
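Illustrative note, not part of this commit: the GDN builder separates spec-decode tokens from regular tokens by expanding the per-request mask to per-token granularity and stably argsorting it, so non-spec token indices come first in their original order. The sketch below uses made-up lengths and casts the mask to int before sorting, which gives the same ordering.

import torch

query_lens = torch.tensor([3, 1, 2])                    # example per-request lengths
spec_sequence_masks = torch.tensor([False, True, False])
spec_token_masks = torch.repeat_interleave(spec_sequence_masks, query_lens)
index = torch.argsort(spec_token_masks.int(), stable=True)
num_non_spec = int((~spec_token_masks).sum())
print(index[:num_non_spec], index[num_non_spec:])       # non-spec indices, then spec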
vllm/v1/attention/backends/linear_attn.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
    mamba_get_block_table_tensor,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class LinearAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "LINEAR_ATTN"

    @staticmethod
    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
        return LinearAttentionMetadataBuilder


@dataclass
class LinearAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    state_indices_tensor: torch.Tensor  # shape: [batch,]


class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
    reorder_batch_threshold: int = 1

    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> LinearAttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        state_indices_tensor = mamba_get_block_table_tensor(
            common_attn_metadata.block_table_tensor,
            common_attn_metadata.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        attn_metadata = LinearAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata
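Illustrative note, not part of this commit: LinearAttention keeps a single recurrent-state slot per request, so the builder above takes only column 0 of the (Mamba-style) block table as state_indices_tensor. A tiny sketch with example values:

import torch

block_table = torch.tensor([[7, 0, 0], [3, 0, 0]])  # example [batch, max_blocks]
state_indices_tensor = block_table[:, 0]            # tensor([7, 3]): one slot per request
print(state_indices_tensor)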
vllm/v1/attention/backends/mamba1_attn.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import (
    BaseMambaAttentionMetadata,
    BaseMambaAttentionMetadataBuilder,
)


class Mamba1AttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "MAMBA1_ATTN"

    @staticmethod
    def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
        return Mamba1AttentionMetadataBuilder


@dataclass
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
    pass


class Mamba1AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
    metadata_cls = Mamba1AttentionMetadata
vllm/v1/attention/backends/mamba2_attn.py (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
def compute_varlen_chunk_metadata(
|
||||
query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
|
||||
|
||||
Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
|
||||
and a physical `chunk_size`, returns three tensors on the same device:
|
||||
- cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
|
||||
logical-chunk lengths (each logical chunk never crosses a sequence or
|
||||
physical-chunk boundary).
|
||||
- last_chunk_indices: (B,) int32 index of the last logical chunk
|
||||
for each sequence (-1 for empty sequences).
|
||||
- seq_idx_chunks: (nchunks,) int32 sequence index for each logical
|
||||
chunk in order.
|
||||
|
||||
This is intentionally lightweight and CPU-side; it mirrors the metadata
|
||||
produced by the V1 Mamba2 metadata builder and is exported so tests
|
||||
(and other callers) can avoid duplicating the logic.
|
||||
"""
|
||||
assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
|
||||
assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
|
||||
device = query_start_loc.device
|
||||
|
||||
qsl64 = query_start_loc.to(torch.int64)
|
||||
starts = qsl64[:-1].tolist()
|
||||
ends = qsl64[1:].tolist()
|
||||
total = int(qsl64[-1].item())
|
||||
|
||||
chunk_lens: list[int] = []
|
||||
seq_idx_chunks: list[int] = []
|
||||
last_chunk_indices: list[int] = [-1] * len(starts)
|
||||
|
||||
for b, (s, e) in enumerate(zip(starts, ends)):
|
||||
if e <= s:
|
||||
# empty sequence
|
||||
continue
|
||||
pos = s
|
||||
while pos < e:
|
||||
# split at both sequence boundaries and physical chunk boundaries
|
||||
room = chunk_size - (pos % chunk_size)
|
||||
take = min(room, e - pos)
|
||||
chunk_lens.append(int(take))
|
||||
seq_idx_chunks.append(b)
|
||||
last_chunk_indices[b] = len(chunk_lens) - 1
|
||||
pos += take
|
||||
|
||||
# Exclusive prefix sum over logical-chunk lengths
|
||||
if chunk_lens:
|
||||
cu_chunk_seqlens = torch.tensor(
|
||||
[0] + list(itertools.accumulate(chunk_lens)),
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
# Final boundary must equal total tokens
|
||||
assert int(cu_chunk_seqlens[-1].item()) == total
|
||||
else:
|
||||
cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
|
||||
last_chunk_indices_t = (
|
||||
torch.tensor(last_chunk_indices, device=device, dtype=torch.int32)
|
||||
if len(starts) > 0
|
||||
else torch.empty((0,), device=device, dtype=torch.int32)
|
||||
)
|
||||
seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
|
||||
return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
|
||||
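# For illustration: with query_start_loc = [0, 3, 10] and chunk_size = 4,
# the loop above produces logical chunks of lengths [3, 1, 4, 2]; sequence 0
# contributes a single 3-token chunk, while sequence 1 is first trimmed to
# the physical boundary at position 4 (1 token) and then split into chunks
# of 4 and 2. The returned tensors would then be:
#   cu_chunk_seqlens   = [0, 3, 4, 8, 10]
#   last_chunk_indices = [0, 3]
#   seq_idx_chunks     = [0, 1, 1, 1]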
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "MAMBA2_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
|
||||
return Mamba2AttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
prep_initial_states: bool = False
|
||||
chunk_size: int = 0
|
||||
|
||||
# Chunk-related metadata (only for prefill)
|
||||
seq_idx_p: torch.Tensor | None = None
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
|
||||
):
|
||||
metadata_cls = Mamba2AttentionMetadata
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
chunk_size = vllm_config.model_config.get_mamba_chunk_size()
|
||||
assert chunk_size is not None, (
|
||||
"chunk_size needs to be set in the model config for Mamba2 models"
|
||||
)
|
||||
self.chunk_size: int = chunk_size
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
num_prefills: int,
|
||||
num_computed_tokens_p_cpu: torch.Tensor,
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Compute chunk-specific metadata for Mamba2.
|
||||
|
||||
The code below carefully constructs the chunks such that:
|
||||
1. Chunks contain tokens from a *single* sequence only.
|
||||
2. For every sequence, we are guaranteed that we can
|
||||
retrieve the mamba state *every* chunk_size tokens.
|
||||
Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
Constraint (2) dramatically simplifies the implementation
|
||||
of prefix caching for mamba2 (wip). We need to take care
|
||||
of the interaction with chunked prefill in order to
|
||||
satisfy constraint (2).
|
||||
"""
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||
- this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
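# For illustration, assuming chunk_size = 8: a single prefill request with
# 5 computed tokens and 12 new tokens first gets a 3-token chunk to reach
# the next chunk boundary at 8, then chunks of 8 and 1 tokens, giving
#   cu_chunk_seqlen    = [0, 3, 11, 12]
#   seq_idx            = [0, 0, 0]
#   last_chunk_indices = [2]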
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Mamba2AttentionMetadata:
|
||||
common = self._compute_common_metadata(
|
||||
common_attn_metadata, num_accepted_tokens=kwargs.get("num_accepted_tokens")
|
||||
)
|
||||
|
||||
seq_idx_p = None
|
||||
cu_chunk_seqlen_p = None
|
||||
last_chunk_indices_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
# Compute seq_idx for prefill only
|
||||
if common.num_prefills > 0:
|
||||
prep_initial_states = (
|
||||
torch.any(common.has_initial_states_p).item()
|
||||
if common.has_initial_states_p is not None
|
||||
else False
|
||||
)
|
||||
|
||||
num_reqs = common.num_reqs
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||
num_prefills,
|
||||
num_computed_tokens_p_cpu,
|
||||
query_start_loc_p_cpu,
|
||||
)
|
||||
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
return replace(
|
||||
common,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx_p=seq_idx_p,
|
||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||
last_chunk_indices_p=last_chunk_indices_p,
|
||||
)
|
||||
464
vllm/v1/attention/backends/mamba_attn.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, ClassVar, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
M = TypeVar("M", bound="BaseMambaAttentionMetadata")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseMambaAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_reqs: int
|
||||
|
||||
# The following tensors only contain prefill requests and will be None if
|
||||
# the batch has no prefill requests.
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
query_start_loc_p: torch.Tensor | None
|
||||
num_computed_tokens_p: torch.Tensor | None
|
||||
state_indices_tensor_p: torch.Tensor | None
|
||||
|
||||
# The following tensors are used for decode requests and
|
||||
# speculative decoding compatibility, and will be None if the batch
|
||||
# has no decode requests.
|
||||
state_indices_tensor_d: torch.Tensor | None
|
||||
query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,]
|
||||
|
||||
# Number of accepted tokens for each spec sequence (for loading correct checkpoint)
|
||||
# Includes the bonus token (so minimum is 1)
|
||||
num_accepted_tokens: torch.Tensor | None # shape: [batch,]
|
||||
|
||||
# The following tensors are only used for prefix caching in all mode and
|
||||
# are None if disabled
|
||||
block_idx_last_scheduled_token: torch.Tensor | None
|
||||
block_idx_first_scheduled_token_p: torch.Tensor | None
|
||||
block_idx_last_computed_token: torch.Tensor | None
|
||||
|
||||
# The following tensor is only used for prefix caching in align mode
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
metadata_cls: type[M]
|
||||
reorder_batch_threshold: int = 1
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
# Will be disabled if speculative decoding is used
|
||||
supports_update_block_table: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
# Enable speculative decoding support
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.num_spec_tokens: int = vllm_config.num_speculative_tokens
|
||||
self.use_spec_decode = self.num_spec_tokens > 0
|
||||
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
if self.compilation_config.max_cudagraph_capture_size is not None:
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.decode_cudagraph_max_bs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
max_num_blocks = cdiv(
|
||||
self.vllm_config.model_config.max_model_len,
|
||||
self.kv_cache_spec.block_size,
|
||||
)
|
||||
# Speculative decoding not supported with prefix caching,
|
||||
# so keep shape consistent with prefill buffer
|
||||
# TODO: reduce this size as needed for decode-only cudagraph capture
|
||||
self.state_indices_tensor_d = torch.empty(
|
||||
(
|
||||
self.decode_cudagraph_max_bs,
|
||||
max_num_blocks,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.block_idx_last_scheduled_token = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.block_idx_last_computed_token = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.state_indices_tensor_d = torch.empty(
|
||||
(self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# For speculative decoding, we need to store the following buffers
|
||||
# for CUDA graph capture during decode
|
||||
if self.num_spec_tokens > 0:
|
||||
self.decode_num_accepted_tokens = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self._init_reorder_batch_threshold(1, self.use_spec_decode)
|
||||
if self.use_spec_decode:
|
||||
self.supports_update_block_table = False
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> M:
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with Mamba.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
|
||||
assert (
|
||||
m.max_query_len <= 1 + self.num_spec_tokens
|
||||
and m.num_reqs <= self.decode_cudagraph_max_bs
|
||||
), (
|
||||
"Mamba only supports decode-only full CUDAGraph capture. "
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
)
|
||||
|
||||
assert m.max_query_len == 1 + self.num_spec_tokens # decode-only
|
||||
|
||||
num_accepted_tokens = None
|
||||
if self.num_spec_tokens > 0:
|
||||
num_accepted_tokens = torch.diff(m.query_start_loc)
|
||||
|
||||
return self.build(0, m, num_accepted_tokens=num_accepted_tokens)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
*,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
**kwargs: Any,
|
||||
) -> M:
|
||||
"""
|
||||
Default build implementation for Mamba-like attention backends.
|
||||
Subclasses (e.g., Mamba2) can override to add additional metadata.
|
||||
"""
|
||||
return self._compute_common_metadata(
|
||||
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
|
||||
)
|
||||
|
||||
def _compute_prefix_caching_block_indices(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
mamba_block_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
# Block index of the last computed token
|
||||
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
|
||||
# which is <= block index for the first scheduled token
|
||||
block_idx_first_scheduled_token = (
|
||||
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
|
||||
)
|
||||
# which is <= block index of the last scheduled token
|
||||
block_idx_last_scheduled_token = (
|
||||
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
|
||||
)
|
||||
# Clamp to 0: with no computed tokens the index would be -1 and break indexing later
|
||||
block_idx_last_computed_token = torch.clamp(
|
||||
block_idx_last_computed_token, min=0
|
||||
)
|
||||
# Clamp to 0: a padded request (0 seq-len) would otherwise give index -1
|
||||
block_idx_last_scheduled_token = torch.clamp(
|
||||
block_idx_last_scheduled_token, min=0
|
||||
)
|
||||
|
||||
return (
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
)
|
||||
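# For illustration, assuming mamba_block_size = 4: a request with 5
# computed tokens and a scheduled seq_len of 9 yields
#   block_idx_last_computed_token   = cdiv(5, 4) - 1 = 1
#   block_idx_first_scheduled_token = cdiv(6, 4) - 1 = 1
#   block_idx_last_scheduled_token  = cdiv(9, 4) - 1 = 2
# i.e. the first scheduled token (position 5, zero-based) still falls in
# block 1 and the last scheduled token (position 8) falls in block 2.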
|
||||
def _compute_common_metadata(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
*,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
) -> M:
|
||||
"""
|
||||
Compute metadata common to both Mamba1 and Mamba2.
|
||||
"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
|
||||
# Treat multi-token queries as decode requests when
|
||||
# speculative decoding is enabled. Otherwise, use the
|
||||
# default decode threshold to prevent misclassification
|
||||
# of prefill queries as decode requests.
|
||||
decode_threshold = (
|
||||
self.reorder_batch_threshold if num_accepted_tokens is not None else 1
|
||||
)
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=decode_threshold
|
||||
)
|
||||
)
|
||||
|
||||
# Need flags to indicate if there are initial states
|
||||
has_initial_states_p = None
|
||||
query_start_loc_p = None
|
||||
query_start_loc_d = None
|
||||
num_computed_tokens = None
|
||||
num_computed_tokens_p = None
|
||||
|
||||
# for prefix caching
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
block_idx_last_computed_token = None
|
||||
block_idx_last_scheduled_token = None
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
|
||||
# Use the whole block table: a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
# Additional cache-related variables:
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
state_indices_tensor = mamba_get_block_table_tensor(
|
||||
common_attn_metadata.block_table_tensor,
|
||||
common_attn_metadata.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)
|
||||
|
||||
if state_indices_tensor.dim() == 1:
|
||||
state_indices_tensor = state_indices_tensor.unsqueeze(-1)
|
||||
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
if self.vllm_config.cache_config.mamba_cache_mode != "all":
|
||||
state_indices_tensor_d = state_indices_tensor_d[
|
||||
:, : 1 + self.num_spec_tokens
|
||||
]
|
||||
state_indices_tensor_p = state_indices_tensor_p[:, 0]
|
||||
|
||||
if num_decodes > 0 and self.use_spec_decode:
|
||||
assert num_accepted_tokens is not None
|
||||
query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
|
||||
num_accepted_tokens = num_accepted_tokens[:num_decodes]
|
||||
|
||||
if num_prefills > 0:
|
||||
if num_computed_tokens is None:
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_p = (
|
||||
num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(
|
||||
query_start_loc_p_cpu,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
)
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
|
||||
metadata = self.metadata_cls(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
state_indices_tensor_p=state_indices_tensor_p,
|
||||
state_indices_tensor_d=state_indices_tensor_d,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
query_start_loc_d=query_start_loc_d,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
num_reqs=num_reqs,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
|
||||
return self._update_metadata_for_cudagraph_capture(metadata)
|
||||
|
||||
def _update_metadata_for_cudagraph_capture(
|
||||
self,
|
||||
metadata: M,
|
||||
) -> M:
|
||||
"""
|
||||
Update the metadata for cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with Mamba.
|
||||
"""
|
||||
state_indices_tensor_d = metadata.state_indices_tensor_d
|
||||
query_start_loc_d = metadata.query_start_loc_d
|
||||
num_accepted_tokens = metadata.num_accepted_tokens
|
||||
block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
|
||||
block_idx_last_computed_token = metadata.block_idx_last_computed_token
|
||||
if (
|
||||
metadata.num_prefills == 0
|
||||
and metadata.num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
padded_bs = metadata.num_reqs
|
||||
self.state_indices_tensor_d[: metadata.num_decodes].copy_(
|
||||
state_indices_tensor_d, non_blocking=True
|
||||
)
|
||||
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
|
||||
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
|
||||
|
||||
if self.use_spec_decode:
|
||||
assert query_start_loc_d is not None
|
||||
assert num_accepted_tokens is not None
|
||||
query_start_loc_d = query_start_loc_d[: padded_bs + 1]
|
||||
self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
|
||||
num_accepted_tokens, non_blocking=True
|
||||
)
|
||||
num_accepted_tokens = self.decode_num_accepted_tokens[:padded_bs]
|
||||
num_accepted_tokens[metadata.num_decodes :] = (
|
||||
1 # pad with the minimum of 1 accepted token (bonus token only)
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
assert block_idx_last_scheduled_token is not None
|
||||
assert block_idx_last_computed_token is not None
|
||||
self.block_idx_last_scheduled_token[: metadata.num_decodes].copy_(
|
||||
block_idx_last_scheduled_token[: metadata.num_decodes],
|
||||
non_blocking=True,
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
: metadata.num_decode_tokens
|
||||
]
|
||||
|
||||
self.block_idx_last_computed_token[: metadata.num_decodes].copy_(
|
||||
block_idx_last_computed_token[: metadata.num_decodes],
|
||||
non_blocking=True,
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
: metadata.num_decode_tokens
|
||||
]
|
||||
|
||||
return replace(
|
||||
metadata,
|
||||
state_indices_tensor_d=state_indices_tensor_d,
|
||||
query_start_loc_d=query_start_loc_d,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
)
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: M,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
state_indices_tensor = mamba_get_block_table_tensor(
|
||||
blk_table,
|
||||
metadata.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)
|
||||
if state_indices_tensor.dim() == 1:
|
||||
state_indices_tensor = state_indices_tensor.unsqueeze(-1)
|
||||
|
||||
assert (
|
||||
metadata.num_prefills + metadata.num_decodes
|
||||
== state_indices_tensor.shape[0]
|
||||
), (
|
||||
"Mismatch in number of requests when updating block table."
|
||||
f" Expected {metadata.num_prefills + metadata.num_decodes}, "
|
||||
f"got {state_indices_tensor.shape[0]}."
|
||||
)
|
||||
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[metadata.num_decodes, metadata.num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
if self.vllm_config.cache_config.mamba_cache_mode != "all":
|
||||
state_indices_tensor_d = state_indices_tensor_d[
|
||||
:, : 1 + self.num_spec_tokens
|
||||
]
|
||||
state_indices_tensor_p = state_indices_tensor_p[:, 0]
|
||||
|
||||
new_metadata = replace(
|
||||
metadata,
|
||||
state_indices_tensor_d=state_indices_tensor_d,
|
||||
state_indices_tensor_p=state_indices_tensor_p,
|
||||
)
|
||||
|
||||
return self._update_metadata_for_cudagraph_capture(new_metadata)
|
||||
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
66
vllm/v1/attention/backends/mla/aiter_triton_mla.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLABackend, AiterMLAImpl
|
||||
|
||||
|
||||
class AiterTritonMLABackend(AiterMLABackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "AITER_TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterTritonMLAImpl"]:
|
||||
return AiterTritonMLAImpl
|
||||
|
||||
|
||||
class AiterTritonMLAImpl(AiterMLAImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
from aiter.ops.triton.mha import flash_attn_varlen_func
|
||||
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
result = self.flash_attn_varlen_func( # type: ignore[call-arg]
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
# Transpose the LSE if Triton MHA is used:
|
||||
# (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
|
||||
if type(result) is tuple and return_softmax_lse:
|
||||
output, lse = result
|
||||
lse = lse.T.contiguous()
|
||||
return (output, lse)
|
||||
return result
|
||||
279
vllm/v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [128]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CUTLASS_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||
return CutlassMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||
return CutlassMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
|
||||
class SM100Workspace:
|
||||
def __init__(self, initial_workspace_size):
|
||||
self._workspace_buf = torch.empty(
|
||||
initial_workspace_size, device="cuda", dtype=torch.uint8
|
||||
)
|
||||
|
||||
self._block_size = 128 # Forced to 128
|
||||
|
||||
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
|
||||
# (assumes all devices are similar)
|
||||
self._sm_count = num_compute_units(0)
|
||||
|
||||
def get_buf(self):
|
||||
return self._workspace_buf
|
||||
|
||||
def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
|
||||
batch_size = attn_metadata.num_reqs
|
||||
max_seq_len = attn_metadata.max_query_len
|
||||
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seq_len * self._block_size,
|
||||
batch_size,
|
||||
self._sm_count,
|
||||
num_kv_splits=num_kv_splits,
|
||||
)
|
||||
|
||||
if self._workspace_buf.shape[0] < workspace_size:
|
||||
self._workspace_buf.resize_(workspace_size)
|
||||
|
||||
|
||||
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
|
||||
|
||||
MAX_HEADS = 128
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
q_pad_num_heads=MAX_HEADS,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl"
|
||||
)
|
||||
|
||||
# TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
|
||||
# issues. In case the code hangs, use:
|
||||
# FORCE_NUM_KV_SPLITS=1
|
||||
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
|
||||
if force_num_kv_splits:
|
||||
logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
|
||||
self._num_kv_splits = int(force_num_kv_splits)
|
||||
else:
|
||||
self._num_kv_splits = -1 # => Auto-detect
|
||||
|
||||
# Share workspace buffer across all executions
|
||||
self._workspace = g_sm100_workspace
|
||||
|
||||
def _sm100_cutlass_mla_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
sm_scale: float,
|
||||
num_kv_splits: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
assert kv_c_and_k_pe_cache.ndim == 3, (
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
|
||||
kv_c_and_k_pe_cache.ndim
|
||||
)
|
||||
)
|
||||
|
||||
B_q, H, D_q_nope = q_nope.shape
|
||||
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||
assert (B_q == B_q_2) and (H == H_2)
|
||||
|
||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||
|
||||
D_latent = 512
|
||||
D_rope = 64
|
||||
assert D_q_nope == D_latent
|
||||
assert D_q_pe == D_rope
|
||||
assert D_ckv == D_latent + D_rope
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
assert B_block_table == B_q
|
||||
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
|
||||
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
|
||||
)
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert seq_lens.dtype == torch.int32, (
|
||||
f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||
)
|
||||
assert page_table.dtype == torch.int32, (
|
||||
f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
)
|
||||
|
||||
dtype = (
|
||||
torch.bfloat16
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
else q_nope.dtype
|
||||
)
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
||||
lse = (
|
||||
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
|
||||
if self.need_to_return_lse_for_decode
|
||||
else torch.Tensor()
|
||||
)
|
||||
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out,
|
||||
lse,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
|
||||
return out, lse
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Adjust workspace size (if necessary)
|
||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||
|
||||
# Run MLA
|
||||
o, lse = self._sm100_cutlass_mla_decode(
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table,
|
||||
self._workspace.get_buf(),
|
||||
self.scale,
|
||||
self._num_kv_splits,
|
||||
)
|
||||
|
||||
return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||
361
vllm/v1/attention/backends/mla/flashattn_mla.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttnMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
|
||||
return FlashAttnMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
|
||||
return FlashAttnMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 9
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if not flash_attn_supports_mla():
|
||||
return "FlashAttention MLA not supported on this device"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
max_num_splits: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
|
||||
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
|
||||
super().__init__(
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
FlashAttnMLAMetadata,
|
||||
supports_dcp_with_varlen=(interleave_size == 1),
|
||||
)
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.fa_aot_schedule = get_flash_attn_version() == 3
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
|
||||
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
||||
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
|
||||
# The +1 is for the tile_count_semaphore (synchronization).
|
||||
# The 4 slots per batch element (num_prepare_batch_vectors) are:
|
||||
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
|
||||
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
|
||||
max_batch_size = max(
|
||||
vllm_config.scheduler_config.max_num_seqs,
|
||||
self.max_cudagraph_size or 0,
|
||||
)
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
1 + round_up(max_batch_size, 4) * 4,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# When using cuda graph, we need to set the upper bound of the
|
||||
# number of splits so that large enough intermediate buffers are
|
||||
# pre-allocated during capture.
|
||||
self.max_num_splits = (
|
||||
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
|
||||
)
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
self.max_num_splits = 1
|
||||
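# Sizing sketch (assuming max_num_seqs = 256 and no larger capture size):
# the persistent scheduler_metadata buffer above then holds
# 1 + round_up(256, 4) * 4 = 1025 int32 entries (~4 KiB), which is reused
# across steps so full CUDA graph replays always see the same address.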
|
||||
def _schedule_decode(
|
||||
self,
|
||||
num_reqs,
|
||||
cu_query_lens,
|
||||
max_query_len,
|
||||
seqlens,
|
||||
max_seq_len,
|
||||
causal,
|
||||
max_num_splits,
|
||||
):
|
||||
if self.fa_aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=num_reqs,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
num_heads_q=self.num_heads * self.dcp_world_size,
|
||||
num_heads_kv=1,
|
||||
headdim=self.mla_dims.qk_rope_head_dim,
|
||||
cache_seqlens=seqlens,
|
||||
qkv_dtype=self.kv_cache_spec.dtype,
|
||||
headdim_v=self.mla_dims.kv_lora_rank,
|
||||
page_size=self.page_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
num_splits=max_num_splits,
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and self.max_cudagraph_size is not None
|
||||
and num_decode_tokens <= self.max_cudagraph_size
|
||||
):
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
scheduler_metadata = self._schedule_decode(
|
||||
num_reqs=seq_lens_device.shape[0],
|
||||
cu_query_lens=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=seq_lens_device,
|
||||
max_seq_len=max_seq_len,
|
||||
causal=True,
|
||||
max_num_splits=max_num_splits,
|
||||
)
|
||||
|
||||
if self.use_full_cuda_graph and scheduler_metadata is not None:
|
||||
n = scheduler_metadata.shape[0]
|
||||
# Ensure the persistent buffer is large enough
|
||||
assert n <= self.scheduler_metadata.shape[0], (
|
||||
f"Scheduler metadata size {n} exceeds buffer size "
|
||||
f"{self.scheduler_metadata.shape[0]}"
|
||||
)
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
# metadata to guarantee the correctness. Otherwise, some thread
|
||||
# blocks may use the invalid scheduler metadata and overwrite the
|
||||
# output buffer.
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
metadata = FlashAttnMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
query_start_loc=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
max_num_splits=max_num_splits,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashAttnMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttnMLAImpl"
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttnMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
|
||||
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
|
||||
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
|
||||
|
||||
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
|
||||
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
|
||||
# to prevent invalid grid configuration during graph capture.
|
||||
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
|
||||
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q=q_pe,
|
||||
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
|
||||
q_v=q_nope,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
cu_seqlens_q=attn_metadata.decode.query_start_loc,
|
||||
max_seqlen_k=attn_metadata.decode.max_seq_len,
|
||||
seqused_k=attn_metadata.decode.seq_lens,
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax_lse=self.need_to_return_lse_for_decode,
|
||||
fa_version=3, # only version 3 is supported
|
||||
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.max_num_splits,
|
||||
cp_world_size=self.dcp_world_size,
|
||||
cp_rank=self.dcp_rank,
|
||||
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
|
||||
)
|
||||
|
||||
if self.need_to_return_lse_for_decode:
|
||||
o, lse = attn_out
|
||||
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
|
||||
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
|
||||
else:
|
||||
o = attn_out
|
||||
return o, None
|
||||
202
vllm/v1/attention/backends/mla/flashinfer_mla.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [32, 64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||
return FlashInferMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||
return FlashInferMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
# FlashInfer MLA kernel requires qk_nope_head_dim == 128
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
if qk_nope_head_dim != 128:
|
||||
return (
|
||||
f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
|
||||
f"but got {qk_nope_head_dim}"
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLAImpl"
|
||||
)
|
||||
|
||||
self._workspace_buffer = g_fi_workspace
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if isinstance(q, tuple):
|
||||
q_nope, q_pe = q
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# trtllm API requires extra dimension q_len_per_request for MTP
|
||||
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
|
||||
logger.warning_once(
|
||||
"""FlashInferMLAImpl got a query of uneven length.
|
||||
This usually indicates an issue in batch reordering
|
||||
or incorrect setup in dummy_run."""
|
||||
)
|
||||
q = q.unsqueeze(1)
|
||||
else:
|
||||
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
|
||||
# Flatten the output for consistent shape
|
||||
o = o.view(-1, o.shape[-2], o.shape[-1])
|
||||
|
||||
# TODO: Return LSE pending support from Flashinfer API:
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/1566
|
||||
return o, None
|
||||
353
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""FlashInfer MLA Sparse Attention Backend.
|
||||
|
||||
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
|
||||
For sparse MLA:
|
||||
- block_tables shape changes from [batch_size, max_num_blocks] (dense)
|
||||
to [batch_size, q_len_per_request, sparse_mla_top_k] (sparse)
|
||||
- The sparse indices represent physical cache slot positions to attend to
|
||||
- sparse_mla_top_k parameter must be set to the topk value
|
||||
"""
|
||||
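# Illustrative shapes (values assumed, not taken from the kernel docs): with
# batch_size=2, q_len_per_request=1 and sparse_mla_top_k=2048, the dense
# block_tables of shape [2, max_num_blocks] is replaced by a sparse index
# tensor of shape [2, 1, 2048] holding physical KV-cache slot positions.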
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_utils import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLASparseBackend(AttentionBackend):
|
||||
"""FlashInfer MLA backend with sparse attention support.
|
||||
|
||||
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
"""
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [32, 64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLASparseImpl"]:
|
||||
return FlashInferMLASparseImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLASparseMetadataBuilder"]:
|
||||
return FlashInferMLASparseMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
# FlashInfer sparse MLA targets Blackwell (SM 10.x)
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
# FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
if qk_nope_head_dim != 128:
|
||||
return (
|
||||
f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
|
||||
f"but got {qk_nope_head_dim}"
|
||||
)
|
||||
# Check for index_topk which indicates sparse model
|
||||
if not hasattr(hf_text_config, "index_topk"):
|
||||
return "FlashInfer MLA Sparse requires model with index_topk config"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMLASparseMetadata(AttentionMetadata):
|
||||
"""Attention metadata for FlashInfer MLA Sparse backend."""
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
num_actual_tokens: int
|
||||
|
||||
# Query start locations
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
|
||||
# Sequence lengths for all requests (context + query)
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# Sparse-specific
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
|
||||
class FlashInferMLASparseMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashInferMLASparseMetadata]
|
||||
):
|
||||
"""Builder for FlashInfer MLA Sparse attention metadata."""
|
||||
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.layer_names = layer_names
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> FlashInferMLASparseMetadata:
|
||||
cm = common_attn_metadata
|
||||
num_tokens = cm.num_actual_tokens
|
||||
|
||||
# Build req_id_per_token mapping
|
||||
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
|
||||
)
|
||||
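# Illustrative: query_start_loc_cpu = [0, 2, 5] -> seg_lengths = [2, 3]
# -> req_id_per_token = [0, 0, 1, 1, 1]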
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
|
||||
torch.from_numpy(req_id_per_token), non_blocking=True
|
||||
)
|
||||
req_id_per_token_tensor = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
return FlashInferMLASparseMetadata(
|
||||
num_reqs=cm.num_reqs,
|
||||
max_query_len=cm.max_query_len,
|
||||
max_seq_len=cm.max_seq_len,
|
||||
num_actual_tokens=cm.num_actual_tokens,
|
||||
query_start_loc=cm.query_start_loc,
|
||||
slot_mapping=cm.slot_mapping,
|
||||
block_table=cm.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token_tensor,
|
||||
seq_lens=cm.seq_lens,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
)
|
||||
|
||||
|
||||
# Global workspace buffer (lazily initialized)
|
||||
_fi_sparse_workspace: torch.Tensor | None = None
|
||||
|
||||
|
||||
def _get_workspace_buffer(device: torch.device) -> torch.Tensor:
|
||||
global _fi_sparse_workspace
|
||||
if _fi_sparse_workspace is None:
|
||||
_fi_sparse_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
return _fi_sparse_workspace
|
||||
|
||||
|
||||
class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata]):
|
||||
"""FlashInfer MLA Sparse implementation.
|
||||
|
||||
Uses the TRT-LLM MLA kernel with sparse_mla_top_k parameter for
|
||||
sparse attention computation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLASparseImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLASparseImpl"
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# MLA-specific dimensions
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.qk_nope_head_dim: int = mla_args["qk_nope_head_dim"]
|
||||
self.qk_rope_head_dim: int = mla_args["qk_rope_head_dim"]
|
||||
|
||||
assert indexer is not None, "Indexer required for sparse MLA"
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
|
||||
self._workspace_buffer: torch.Tensor | None = None
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMLASparseMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
topk_indices_physical, seq_lens = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token[:num_actual_toks],
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
return_valid_counts=True,
|
||||
)
|
||||
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = _get_workspace_buffer(q.device)
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q.unsqueeze(1),
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=topk_indices_physical.unsqueeze(1),
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=attn_metadata.topk_tokens,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
sparse_mla_top_k=attn_metadata.topk_tokens,
|
||||
)
|
||||
return o.view(-1, o.shape[-2], o.shape[-1]), None
|
||||
317
vllm/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,317 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
from vllm.v1.attention.ops.flashmla import (
|
||||
FlashMLASchedMeta,
|
||||
flash_mla_with_kvcache,
|
||||
flash_mla_with_kvcache_fp8,
|
||||
get_mla_metadata,
|
||||
get_mla_metadata_dense_fp8,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if use_sparse:
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported
|
||||
|
||||
return is_flashmla_sparse_supported()[1]
|
||||
else:
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
return is_flashmla_dense_supported()[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
scheduler_metadata: FlashMLASchedMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
|
||||
# ^ TODO(matt): tune this
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
|
||||
)
|
||||
|
||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
|
||||
num_sms = num_compute_units(self.device.index)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
(num_sms, 8),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.cg_buf_num_splits = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_seqs + 1),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
# we use the max but all should be the same due to uniform length requirement
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
|
||||
scheduler_metadata, _ = get_mla_metadata(
|
||||
seq_lens_device,
|
||||
num_q_tokens_per_head_k,
|
||||
1, # MQA for the decode path
|
||||
is_fp8_kvcache=self.is_fp8_kvcache,
|
||||
)
|
||||
if self.is_fp8_kvcache:
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
|
||||
seq_lens_device,
|
||||
num_q_tokens_per_head_k,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
|
||||
scheduler_metadata.num_splits = num_splits
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
is_supported, reason = is_flashmla_dense_supported()
|
||||
assert is_supported, reason
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl"
|
||||
)
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# TODO: (zyongye) decode function for mla here
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
# mypy assertion: q is now always a tensor
|
||||
assert isinstance(q, torch.Tensor)
|
||||
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
|
||||
scheduler_metadata = attn_metadata.decode.scheduler_metadata
|
||||
if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
|
||||
device = q.device
|
||||
dtype = torch.int32
|
||||
|
||||
B = q.shape[0]
|
||||
# block_table shape: [batch_size, max_num_blocks_per_seq]
|
||||
# The number of blocks per sequence is in the second dimension
|
||||
topk = attn_metadata.decode.block_table.shape[-1]
|
||||
B_TOPK = 64
|
||||
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
|
||||
end_block_idx = topk // B_TOPK
|
||||
|
||||
# Single partition => num_sm_parts = 1
|
||||
# TileSchedulerMetaDataSize = 8, layout:
|
||||
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
|
||||
# begin_n_split_idx, _, _, _]
|
||||
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
|
||||
tile_scheduler_metadata[0, 0] = 0 # begin_idx
|
||||
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
|
||||
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
|
||||
tile_scheduler_metadata[0, 3] = end_block_idx
|
||||
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
|
||||
# fields [5..7] stay 0
|
||||
|
||||
# Non-split path ignores num_splits, but the API requires it:
|
||||
# zeros of length B+1
|
||||
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
|
||||
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
|
||||
scheduler_metadata.num_splits = num_splits
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
o, lse = flash_mla_with_kvcache_fp8(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
|
||||
num_splits=scheduler_metadata.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
else:
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=scheduler_metadata,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
is_fp8_kvcache=False,
|
||||
)
|
||||
|
||||
o = reshape_attn_output_for_spec_decode(o)
|
||||
|
||||
return o, lse
|
||||
847
vllm/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,847 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_utils import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.attention.ops.flashmla import (
|
||||
FlashMLASchedMeta,
|
||||
flash_mla_sparse_fwd,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# For FP8 sparse attention we have two implementations:
|
||||
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode; this is
|
||||
#    done by treating all tokens as a single batch.
|
||||
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
|
||||
# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using
|
||||
# the FP8 decode kernel for decode.
|
||||
# Currently we use #1 when the number of heads per rank is low (i.e., with high TP) since the BF16
|
||||
# prefill kernel requires padding the number of heads to 128 while the decode kernel does not,
|
||||
# so when the per-rank head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed
|
||||
# batch mode (#1).
|
||||
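# Illustrative head counts (not from the source): a model with 128 query heads
# at TP=8 has 16 heads per rank (< 32), so the mixed batch mode (#1) is used;
# at TP=2 it has 64 heads per rank and the separate prefill/decode mode (#2)
# is used.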
MIN_HEADS_FOR_BF16_PREFILL = 32
|
||||
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
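# Byte-layout check (derived from the note above):
#   512 * 1 (fp8 NoPE) + 4 * 4 (fp32 scales) + 64 * 2 (bf16 RoPE) = 656 bytes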
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8_ds_mla",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
|
||||
return FlashMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLASparseImpl"]:
|
||||
return FlashMLASparseImpl
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if cache_dtype_str == "fp8_ds_mla":
|
||||
# custom storage format is 656 bytes
|
||||
# see FlashMLA readme.md for details
|
||||
return (num_blocks, block_size, 656)
|
||||
else:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata(AttentionMetadata):
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
@dataclass
|
||||
class FP8KernelMetadata:
|
||||
scheduler_metadata: FlashMLASchedMeta
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
@dataclass
|
||||
class FP8SeparatePrefillDecode:
|
||||
@dataclass
|
||||
class Decode:
|
||||
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
|
||||
decode_query_len: int # needed for reshape in spec decode
|
||||
|
||||
@dataclass
|
||||
class Prefill:
|
||||
# Sequence lengths (context + query) for prefill requests
|
||||
# Shape: [num_prefill_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# Request ID for each token: -1 for decode tokens, request index
|
||||
# (0, 1, 2, ...) for prefill tokens.
|
||||
# Shape: [num_actual_tokens]
|
||||
request_ids: torch.Tensor
|
||||
|
||||
# Workspace start offsets for all prefill requests
|
||||
# Shape: [num_prefill_reqs], adjusted in-place per chunk to be
|
||||
# 0-indexed within each chunk. Used to map prefill tokens to workspace
|
||||
# offsets in convert_logical_index_to_physical_index
|
||||
workspace_starts: torch.Tensor
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
"""Metadata for a chunk of prefill requests.
|
||||
|
||||
Prefill requests may be chunked to fit within the fixed workspace size.
|
||||
"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
tokens_slice: slice
|
||||
block_table: torch.Tensor
|
||||
req_start_idx: int
|
||||
workspace_starts: torch.Tensor
|
||||
chunk_tot_seqlen: int
|
||||
|
||||
chunks: list[Chunk]
|
||||
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
|
||||
decode: Decode | None = None
|
||||
prefill: Prefill | None = None
|
||||
|
||||
fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
|
||||
fp8_use_mixed_batch: bool = False
|
||||
|
||||
|
||||
def get_prefill_workspace_size(max_model_len: int):
|
||||
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
|
||||
# May be tuned later.
|
||||
# Memory usage: 5 * max_model_len * 576 * 2 bytes
|
||||
# Example: DeepSeek-V3.2 with max_model_len=163840 ->
|
||||
# 5 * 163840 * 576 * 2 = ~900 MB
|
||||
# This fits nicely below the typical MoE workspace size of >2GB so this is "free"
|
||||
return max_model_len * 5
|
||||
|
||||
|
||||
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.layer_names = layer_names
|
||||
cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
|
||||
# Treat requests with query length <= 1 as decodes to match the
|
||||
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||
|
||||
sm_count = num_compute_units(device.index)
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
# FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
|
||||
self.fp8_decode_padded_heads = (
|
||||
FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
|
||||
)
|
||||
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
# Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
|
||||
self.topk_tokens_tensor = torch.full(
|
||||
(max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
|
||||
)
|
||||
# Shape: [max_num_seqs], all elements = max_model_len
|
||||
self.max_model_len_tensor = torch.full(
|
||||
(max_num_seqs,),
|
||||
self.model_config.max_model_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices not None
|
||||
self.dummy_block_table = torch.empty(
|
||||
(max_num_seqs, 1), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Equation taken from FlashMLA/csrc/api/sparse_decode.h
|
||||
# For sparse FP8 decode, the formula depends on architecture:
|
||||
# - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
|
||||
# - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
|
||||
# - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
|
||||
# For max buffer size, use s_q = 1 (the case that produces largest output)
|
||||
# Use padded head count since that's what will be passed to the kernel
|
||||
h_q = self.fp8_decode_padded_heads
|
||||
if current_platform.is_device_capability_family(100):
|
||||
# SM100: head64/head64x2 use the full SM count (head128 would use half;
# the full count is a safe upper bound for buffer sizing)
|
||||
max_num_sm_parts = sm_count
|
||||
else:
|
||||
# SM90 uses h_q/64 divisor
|
||||
max_num_sm_parts = sm_count // max(1, h_q // 64)
|
||||
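# Illustrative (assumed sm_count=132, e.g. an SM90 part, with padded h_q=128):
# max_num_sm_parts = 132 // (128 // 64) = 66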
self.tile_scheduler_metadata_buffer = torch.empty(
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
# see: FlashMLA/csrc/params.h
|
||||
(max_num_sm_parts, 8),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
# Sized for per-request batching (num_decodes + 1)
|
||||
self.num_splits_buffer = torch.empty(
|
||||
(max_num_seqs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _build_fp8_mixed_decode_prefill(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
|
||||
"""Build FP8 metadata treating all tokens as one mixed batch.
|
||||
|
||||
This matches main branch's approach and avoids the BF16 prefill kernel
|
||||
which has head padding overhead when num_heads is small (high TP case).
|
||||
"""
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
# Use padded head count since that's what the kernel will see
|
||||
padded_heads = self.fp8_decode_padded_heads
|
||||
|
||||
# Build metadata for all tokens as a single batch
|
||||
scheduler_metadata, _ = get_mla_metadata(
|
||||
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
|
||||
num_q_tokens_per_head_k=num_tokens * padded_heads,
|
||||
topk=self.topk_tokens,
|
||||
num_heads_q=padded_heads,
|
||||
num_heads_k=1,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
cache_lens=self.max_model_len_tensor[:1],
|
||||
dummy_block_table=self.dummy_block_table[:1],
|
||||
)
|
||||
|
||||
return fp8_metadata
|
||||
|
||||
def _build_fp8_separate_prefill_decode(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold or 1,
|
||||
require_uniform=True,
|
||||
)
|
||||
)
|
||||
|
||||
FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
|
||||
fp8_metadata = FP8Meta(
|
||||
num_decodes=num_decodes,
|
||||
num_prefills=num_prefills,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
)
|
||||
|
||||
# Extract prefill sequence lengths (context + query, not just query)
|
||||
# Decode requests come first in the batch, prefill requests follow
|
||||
prefill_seq_lens = None
|
||||
prefill_request_id = None
|
||||
prefill_workspace_starts = None
|
||||
prefill_chunks = None
|
||||
|
||||
# For pure decode batches, prefill_request_id will be None
|
||||
# For mixed batches, it will have -1 for decode and request_id for prefill
|
||||
if num_prefills > 0:
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
|
||||
prefill_seq_lens = seq_lens[num_decodes:]
|
||||
|
||||
# Build prefill_request_id: -1 for decode, request index for
|
||||
# prefill. This enables a single
|
||||
# convert_logical_index_to_physical_index call for all tokens
|
||||
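# Illustrative: with 2 decode tokens followed by one 3-token prefill request,
# prefill_request_id = [-1, -1, 0, 0, 0]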
prefill_request_id = torch.full(
|
||||
(num_tokens,), -1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
# Map prefill tokens to their request IDs (0, 1, 2, ...)
|
||||
for req_idx in range(num_prefills):
|
||||
# Get query token range for this prefill request
|
||||
global_req_idx = num_decodes + req_idx
|
||||
req_query_start = query_start_loc_cpu[global_req_idx]
|
||||
req_query_end = query_start_loc_cpu[global_req_idx + 1]
|
||||
prefill_request_id[req_query_start:req_query_end] = req_idx
|
||||
|
||||
# will be adjusted by chunk loop
|
||||
prefill_workspace_starts_cpu = torch.zeros(
|
||||
num_prefills, dtype=torch.int32, pin_memory=True
|
||||
)
|
||||
prefill_workspace_starts_cpu[1:] = torch.cumsum(
|
||||
prefill_seq_lens_cpu[:-1], dim=0
|
||||
)
|
||||
# populated by non-blocking copy after prefill_workspace_starts_cpu is
|
||||
# updated by each chunk
|
||||
prefill_workspace_starts = torch.empty(
|
||||
num_prefills, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Chunk prefill requests to fit within workspace size
|
||||
max_prefill_buffer_size = get_prefill_workspace_size(
|
||||
self.vllm_config.model_config.max_model_len
|
||||
)
|
||||
chunk_bounds = split_prefill_chunks(
|
||||
prefill_seq_lens_cpu, max_prefill_buffer_size
|
||||
)
|
||||
|
||||
prefill_chunks = []
|
||||
for chunk_start, chunk_end in chunk_bounds:
|
||||
# Adjust workspace_starts in-place per chunk to be
|
||||
# 0-indexed within each chunk
|
||||
# Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
|
||||
# Initial: workspace_starts=[0,10,25,45]
|
||||
# After: workspace_starts=[0,10,0,20]
|
||||
# (chunk 0 starts at 0, chunk 1 starts at 0)
|
||||
offset = prefill_workspace_starts_cpu[chunk_start].item()
|
||||
prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset
|
||||
|
||||
chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
|
||||
chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
|
||||
token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
|
||||
token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
|
||||
tokens_slice = slice(token_start, token_end)
|
||||
|
||||
# Create chunk view of gpu tensor
|
||||
chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
|
||||
chunk_block_table = common_attn_metadata.block_table_tensor[
|
||||
num_decodes + chunk_start : num_decodes + chunk_end
|
||||
]
|
||||
|
||||
prefill_chunks.append(
|
||||
FP8Meta.Prefill.Chunk(
|
||||
seq_lens=chunk_seq_lens,
|
||||
tokens_slice=tokens_slice,
|
||||
block_table=chunk_block_table,
|
||||
req_start_idx=chunk_start,
|
||||
workspace_starts=chunk_workspace_starts,
|
||||
chunk_tot_seqlen=chunk_tot_seqlen,
|
||||
)
|
||||
)
|
||||
|
||||
prefill_workspace_starts.copy_(
|
||||
prefill_workspace_starts_cpu, non_blocking=True
|
||||
)
|
||||
|
||||
fp8_metadata.prefill = FP8Meta.Prefill(
|
||||
seq_lens=prefill_seq_lens,
|
||||
request_ids=prefill_request_id,
|
||||
workspace_starts=prefill_workspace_starts,
|
||||
chunks=prefill_chunks,
|
||||
)
|
||||
|
||||
if num_decodes > 0:
|
||||
# Compute decode_query_len for spec decode (uniform due to require_uniform)
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
|
||||
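# Illustrative: with uniform spec decode emitting 2 tokens per request,
# query_start_loc_cpu = [0, 2, 4, ...] -> decode_query_len = 2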
|
||||
# Use padded head count since that's what the kernel will see
|
||||
padded_heads = self.fp8_decode_padded_heads
|
||||
scheduler_metadata, _ = get_mla_metadata(
|
||||
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
|
||||
num_q_tokens_per_head_k=decode_query_len * padded_heads,
|
||||
topk=self.topk_tokens,
|
||||
num_heads_q=padded_heads,
|
||||
num_heads_k=1,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
dummy_block_table=self.dummy_block_table[:num_decodes],
|
||||
cache_lens=self.max_model_len_tensor[:num_decodes],
|
||||
)
|
||||
fp8_metadata.decode = FP8Meta.Decode(
|
||||
kernel_metadata=kernel_meta,
|
||||
decode_query_len=decode_query_len,
|
||||
)
|
||||
|
||||
return fp8_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> FlashMLASparseMetadata:
|
||||
cm = common_attn_metadata
|
||||
num_tokens = cm.num_actual_tokens
|
||||
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
|
||||
)
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
|
||||
torch.from_numpy(req_id_per_token), non_blocking=True
|
||||
)
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
fp8_extra_metadata: (
|
||||
FlashMLASparseMetadata.FP8SeparatePrefillDecode
|
||||
| FlashMLASparseMetadata.FP8KernelMetadata
|
||||
| None
|
||||
) = None
|
||||
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
|
||||
if self.use_fp8_kv_cache:
|
||||
if fp8_use_mixed_batch:
|
||||
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
|
||||
else:
|
||||
fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)
|
||||
|
||||
metadata = FlashMLASparseMetadata(
|
||||
num_reqs=cm.num_reqs,
|
||||
max_query_len=cm.max_query_len,
|
||||
max_seq_len=cm.max_seq_len,
|
||||
num_actual_tokens=cm.num_actual_tokens,
|
||||
query_start_loc=cm.query_start_loc,
|
||||
slot_mapping=cm.slot_mapping,
|
||||
block_table=cm.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
fp8_extra_metadata=fp8_extra_metadata,
|
||||
fp8_use_mixed_batch=fp8_use_mixed_batch,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
@staticmethod
|
||||
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
|
||||
# FP8 decode kernel only supports h_q = 64 or 128
|
||||
# Compute padded head count for decode
|
||||
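# Illustrative: num_heads=16 -> padded to 64; num_heads=128 -> stays 128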
return 64 if num_heads <= 64 else 128
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
# Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
|
||||
self.prefill_padding = (
|
||||
128 if current_platform.is_device_capability_family(100) else 64
|
||||
)
|
||||
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
|
||||
|
||||
if kv_cache_dtype == "fp8_ds_mla":
|
||||
# Reserve workspace during initialization
|
||||
vllm_config = get_current_vllm_config()
|
||||
assert vllm_config is not None and vllm_config.model_config is not None
|
||||
prefill_workspace_size = get_prefill_workspace_size(
|
||||
vllm_config.model_config.max_model_len
|
||||
)
|
||||
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
|
||||
(self.prefill_bf16_workspace,) = (
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(self.prefill_workspace_shape, torch.bfloat16)
|
||||
)
|
||||
)
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Convert per-request indices to global slots (decode) or workspace
|
||||
# offsets (prefill).
|
||||
topk_indices = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
)
|
||||
|
||||
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
|
||||
|
||||
def _forward_fp8_kv_separate_prefill_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
fp8_metadata = attn_metadata.fp8_extra_metadata
|
||||
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
|
||||
num_decodes = fp8_metadata.num_decodes
|
||||
|
||||
prefill_request_ids = None
|
||||
prefill_workspace_starts = None
|
||||
has_prefill_workspace = False
|
||||
if fp8_metadata.prefill is not None:
|
||||
prefill_request_ids = fp8_metadata.prefill.request_ids
|
||||
prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
|
||||
has_prefill_workspace = True
|
||||
|
||||
# Convert per-request indices to global slots (decode) or workspace
|
||||
# offsets (prefill).
|
||||
# For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
|
||||
# For BF16 cache: always use global cache slots (no workspace)
|
||||
# prefill_workspace_starts has been adjusted in-place per chunk so
|
||||
# prefill indices automatically come out chunk-local
|
||||
topk_indices = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
HAS_PREFILL_WORKSPACE=has_prefill_workspace,
|
||||
prefill_workspace_request_ids=prefill_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
fp8_metadata = attn_metadata.fp8_extra_metadata
|
||||
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
|
||||
|
||||
def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
|
||||
# Reshape q: (num_decode_tokens, num_heads, head_dim)
|
||||
# -> (num_decodes, seq_len, num_heads, head_dim)
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
seq_len = q.shape[1]
|
||||
# Reshape topk_indices: (num_decode_tokens, topk)
|
||||
# -> (num_decodes, seq_len, topk)
|
||||
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
|
||||
assert fp8_metadata.decode is not None
|
||||
attn_out, _ = self._fp8_flash_mla_kernel(
|
||||
q=q,
|
||||
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
|
||||
topk_indices=topk_indices,
|
||||
kernel_metadata=fp8_metadata.decode.kernel_metadata,
|
||||
)
|
||||
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
|
||||
# -> (num_decode_tokens, num_heads, head_dim_v)
|
||||
return reshape_attn_output_for_spec_decode(attn_out)
|
||||
|
||||
num_decode_tokens = fp8_metadata.num_decode_tokens
|
||||
num_prefill_tokens = fp8_metadata.num_prefill_tokens
|
||||
|
||||
# Pure decode: direct call without allocation
|
||||
if num_decode_tokens > 0 and num_prefill_tokens == 0:
|
||||
assert fp8_metadata.decode is not None
|
||||
attn_out = _fp8_decode(q, topk_indices)
|
||||
else:
|
||||
# Mixed or pure prefill: allocate output tensor
|
||||
attn_out = q.new_empty(
|
||||
(attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
if num_decode_tokens > 0:
|
||||
attn_out[:num_decode_tokens] = _fp8_decode(
|
||||
q[:num_decode_tokens], topk_indices[:num_decode_tokens]
|
||||
)
|
||||
|
||||
assert fp8_metadata.prefill is not None
|
||||
for chunk in fp8_metadata.prefill.chunks:
|
||||
chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
|
||||
ops.cp_gather_and_upconvert_fp8_kv_cache(
|
||||
kv_c_and_k_pe_cache,
|
||||
chunk_workspace,
|
||||
chunk.block_table,
|
||||
chunk.seq_lens,
|
||||
chunk.workspace_starts,
|
||||
len(chunk.block_table),
|
||||
)
|
||||
|
||||
chunk_q = q[chunk.tokens_slice]
|
||||
chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]
|
||||
|
||||
attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
|
||||
chunk_q,
|
||||
chunk_workspace,
|
||||
chunk_topk_indices_workspace,
|
||||
)
|
||||
|
||||
return attn_out
|
||||
|
||||
def _forward_fp8_kv_mixed_batch(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Mixed batch FP8 forward path that treats all tokens as one batch.
|
||||
|
||||
This is equivalent to main branch's approach and avoids the BF16
|
||||
prefill kernel which has head padding overhead when num_heads is small.
|
||||
Used when fp8_use_mixed_batch is True.
|
||||
"""
|
||||
# Convert per-request indices to global slots (decode) or workspace
|
||||
# offsets (prefill).
|
||||
topk_indices = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
)
|
||||
|
||||
assert attn_metadata.fp8_extra_metadata is not None
|
||||
assert isinstance(
|
||||
attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
|
||||
)
|
||||
fp8_metadata = attn_metadata.fp8_extra_metadata
|
||||
|
||||
_attn_out, _ = self._fp8_flash_mla_kernel(
|
||||
q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
|
||||
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
|
||||
topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk)
|
||||
kernel_metadata=fp8_metadata,
|
||||
)
|
||||
|
||||
# Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
|
||||
return _attn_out.squeeze(0)
|
||||
|
||||
def _fp8_flash_mla_kernel(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# q shape: (batch, seq_len, num_heads, head_dim)
|
||||
actual_num_heads = q.size(2)
|
||||
padded_num_heads = self.fp8_decode_padded_heads
|
||||
|
||||
# Pad query if needed (kernel only supports h_q = 64 or 128)
|
||||
if actual_num_heads < padded_num_heads:
|
||||
logger.warning_once(
|
||||
f"Padding num_heads from {actual_num_heads} to "
|
||||
f"{padded_num_heads} for FP8 sparse decode kernel"
|
||||
)
|
||||
q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
|
||||
q_padded[:, :, :actual_num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
out, lse = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
|
||||
block_table=kernel_metadata.dummy_block_table,
|
||||
head_dim_v=512,
|
||||
cache_seqlens=kernel_metadata.cache_lens,
|
||||
tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
|
||||
is_fp8_kvcache=True,
|
||||
indices=topk_indices,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
|
||||
# Slice output back to actual head count if we padded
|
||||
if actual_num_heads < padded_num_heads:
|
||||
out = out[:, :, :actual_num_heads, :]
|
||||
|
||||
return out, lse
|
||||
|
||||
def _bf16_flash_mla_kernel(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = q.shape[0]
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||
-1, 1, kv_c_and_k_pe_cache.shape[-1]
|
||||
)
|
||||
|
||||
# NOTE(Chen): kernel requires num_local_head to be a multiple of
|
||||
# 64 on hopper and 128 on blackwell
|
||||
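# Illustrative: num_heads=32 on Hopper (prefill_padding=64): 32 % 64 != 0 and
# 64 % 32 == 0, so q is padded to 64 heads and the extra head outputs are
# sliced off afterwards.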
if self.num_heads % self.prefill_padding != 0:
|
||||
assert self.prefill_padding % self.num_heads == 0
|
||||
logger.warning_once(
|
||||
f"Padding num_heads from {self.num_heads} to "
|
||||
f"{self.prefill_padding} for BF16 sparse prefill kernel"
|
||||
)
|
||||
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
|
||||
q_padded[:, : self.num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
topk_indices = topk_indices.view(num_tokens, 1, -1)
|
||||
output = flash_mla_sparse_fwd(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
|
||||
)[0]
|
||||
output = output[:, : self.num_heads, :]
|
||||
return output
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# NOTE(lucas): the sparse FlashMLA kernels use the
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
# Concatenate q if it's a tuple (q_nope, q_pe)
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
# Get topk indices
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
|
||||
|
||||
if not use_fp8_cache:
|
||||
attn_out = self._forward_bf16_kv(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
elif attn_metadata.fp8_use_mixed_batch:
|
||||
attn_out = self._forward_fp8_kv_mixed_batch(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_out = self._forward_fp8_kv_separate_prefill_decode(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
|
||||
return attn_out, None
|
||||
386
vllm/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "DEEPSEEK_V32_INDEXER"
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [1 if current_platform.is_rocm() else 64]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 128]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
|
||||
return DeepseekV32IndexerMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
assert num_kv_heads == 1
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
if include_num_layers_dimension:
|
||||
return (0, 1, 2, 3)
|
||||
return (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
block_table: torch.Tensor
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
token_to_seq: torch.Tensor
|
||||
total_seq_lens: int
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillMetadata:
|
||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
decode_lens: torch.Tensor
|
||||
requires_padding: bool
|
||||
schedule_metadata: torch.Tensor
|
||||
use_large_context_topk: bool
|
||||
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerMetadata:
|
||||
# FIXME (zyongye)
|
||||
# hacky way to access the data now, need to be in chunked meta
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
# The dimension of the attention heads
|
||||
head_dim: int
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
decode: DeepSeekV32IndexerDecodeMetadata | None = None
|
||||
prefill: DeepseekV32IndexerPrefillMetadata | None = None
|
||||
|
||||
|
||||
# TODO (zyongye) optimize this, this is now vibe coded
|
||||
def kv_spans_from_batches(
|
||||
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
start_seq_loc: 1D long tensor [B+1], cumulative counts of
|
||||
selected tokens per batch.
|
||||
Example: [0, 2, 4, 7] ->
|
||||
batch sizes (selected) [2, 2, 3], N=7 tokens total.
|
||||
seq_len_per_batch: 1D long tensor [B],
|
||||
full sequence length (KV length) of each batch.
|
||||
Example: [5, 9, 4].
|
||||
|
||||
Returns:
|
||||
start_tensor: 1D long tensor [N], start offset in the
|
||||
concatenated KV cache for each token's batch.
|
||||
end_location: 1D long tensor [N],
|
||||
**exclusive** end = start + token's local position.
|
||||
(So the attended KV slice is kv[start:end].)
|
||||
|
||||
Assumes each batch contributes its full `seq_len_per_batch[i]`
|
||||
keys to the KV cache, and the selected tokens within a batch
|
||||
are the **last** `counts[i]` positions of that sequence.
|
||||
"""
|
||||
q = start_seq_loc.to(dtype=torch.long)
|
||||
L = seq_len_per_batch.to(dtype=torch.long)
|
||||
assert q.dim() == 1 and L.dim() == 1
|
||||
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
|
||||
|
||||
# Selected tokens per batch and totals
|
||||
counts = q[1:] - q[:-1] # [B]
|
||||
N = int(q[-1].item()) # total selected tokens
|
||||
B = L.numel()
|
||||
|
||||
if N == 0:
|
||||
return (
|
||||
torch.empty(0, dtype=torch.long, device=device),
|
||||
torch.empty(0, dtype=torch.long, device=device),
|
||||
)
|
||||
|
||||
# KV start offsets per batch in the concatenated KV cache
|
||||
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
|
||||
|
||||
# For each selected token, which batch does it belong to?
|
||||
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
|
||||
|
||||
# Map batch KV start to each token
|
||||
start_tensor = kv_starts_per_batch[batch_id] # [N]
|
||||
|
||||
# End-align local positions inside each batch:
|
||||
# local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
|
||||
L_expand = torch.repeat_interleave(L, counts) # [N]
|
||||
m_expand = torch.repeat_interleave(counts, counts) # [N]
|
||||
# position within the selected block: 1..counts[b]
|
||||
pos_within = (
|
||||
torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
|
||||
)
|
||||
|
||||
local_pos = L_expand - m_expand + pos_within # [N], 1-based
|
||||
end_location = start_tensor + local_pos # exclusive end
|
||||
|
||||
return start_tensor.int().to(device), end_location.int().to(device)
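# Worked example (matches the docstring above): start_seq_loc=[0, 2, 4, 7]
# and seq_len_per_batch=[5, 9, 4] give counts=[2, 2, 3] and per-batch KV
# starts [0, 5, 14], so
#   start_tensor = [0, 0, 5, 5, 14, 14, 14]
#   end_location = [4, 5, 13, 14, 16, 17, 18]
# i.e. the first selected token of batch 1 attends to kv[5:13].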
|
||||
|
||||
|
||||
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
|
||||
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
|
||||
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
|
||||
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
|
||||
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
|
||||
# within the flashmla_sparse workspace.
|
||||
# For DeepSeek-V3.2, the max_model_len is 163840.
|
||||
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
|
||||
return max_model_len * 40
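# Note: the returned value is an entry count, not bytes; per the sizing
# comment above, the byte footprint is roughly 40 * max_model_len * 132.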
|
||||
|
||||
|
||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
scheduler_config = self.vllm_config.scheduler_config
|
||||
# NOTE(Chen): an estimated max size of flattened_kv. Need to double-check.
|
||||
self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
|
||||
self.num_speculative_tokens = (
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
if self.num_speculative_tokens > 1:
|
||||
raise ValueError(
|
||||
"Sparse MLA only supports "
|
||||
"num_speculative_tokens <= 1 because the DeepGEMM "
|
||||
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
|
||||
f"Got num_speculative_tokens={self.num_speculative_tokens}."
|
||||
)
|
||||
self.reorder_batch_threshold += self.num_speculative_tokens
|
||||
|
||||
sm_count = num_compute_units(self.device.index)
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# See: DeepGEMM/csrc/apis/attention.hpp
|
||||
self.scheduler_metadata_buffer = torch.empty(
|
||||
(self.num_sms + 1, 2), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
def build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1]
|
||||
- query_start_loc_cpu[reqs_start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(
|
||||
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
|
||||
).to(self.device)
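# e.g. seq_lens [3, 2] within this chunk -> token_to_seq = [0, 0, 0, 1, 1]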
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
token_to_seq=token_to_seq,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> DeepseekV32IndexerMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu[num_decodes:],
|
||||
self.max_prefill_buffer_size,
|
||||
request_offset=num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(
|
||||
common_attn_metadata.query_start_loc[: num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes],
|
||||
)
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
|
||||
# Use CPU values to avoid a GPU sync that would break async scheduling
|
||||
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||
|
||||
# Decide which top-k kernel to use based on batch size and sequence length
|
||||
batch_size = num_decodes
|
||||
_is_large_context = common_attn_metadata.max_seq_len > 8192
|
||||
|
||||
# Decision logic based on micro-benchmark results:
|
||||
# - large_context_topk wins for batch <= 128 and seq_len > 8K
|
||||
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
|
||||
use_large_context_topk = batch_size <= 128 and _is_large_context
|
||||
|
||||
next_n = 1 + self.num_speculative_tokens
|
||||
if next_n > 1:
|
||||
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
|
||||
else:
|
||||
offsets = None
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
||||
if current_platform.is_cuda() and has_deep_gemm():
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||
)
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
||||
# Padded CUDA graph requests have block_table entries of -1.
|
||||
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
|
||||
# This is safe because padded requests have seq_lens=0, so the
|
||||
# kernel produces no meaningful output for those rows.
|
||||
block_table.clamp_(min=0)
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=block_table,
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
use_large_context_topk=use_large_context_topk,
|
||||
offsets=offsets,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
284
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [1]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: torch.Tensor | None = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: torch.Tensor | None = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: torch.Tensor | None = None
|
||||
# The query indptr, shape: [num_decode + 1]
|
||||
qo_indptr: torch.Tensor | None = None
|
||||
# The dtype of MLA out tensor
|
||||
attn_out_dtype: torch.dtype = torch.bfloat16
|
||||
# The max query output length: int
|
||||
max_qo_len: int | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
# TODO(luka, lucas): audit this as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
|
||||
)
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.decode_attn_out_dtype = vllm_config.model_config.dtype
|
||||
# kernel block size is always 1.
|
||||
max_num_pages_per_req = vllm_config.model_config.max_model_len
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
# Preparing persistent buffers
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
|
||||
# paged_kv_last_page_len is always 1s (kernel block size is always 1),
|
||||
# so we create it once and reuse slices in both eager and cudagraph modes.
|
||||
self.paged_kv_last_page_len = torch.ones(
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.qo_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> AiterMLADecodeMetadata:
|
||||
# kernel block size is always 1, although the kv block size is not 1.
|
||||
device = self.device
|
||||
num_reqs = seq_lens_device.size(0)
|
||||
|
||||
mask = torch.arange(
|
||||
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
||||
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
|
||||
paged_kv_indices = block_table_tensor[mask]
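# Since the kernel block size is 1, the mask keeps the first seq_len entries
# of each block-table row; indexing with it flattens them into one ragged
# list of page ids, e.g. seq_lens [3, 2] -> 5 page ids (illustrative values).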
|
||||
|
||||
# kernel block size is always 1, so each page has exactly 1 token.
|
||||
# last_page_len is always 1 - just slice the pre-initialized buffer.
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||
|
||||
paged_kv_indptr = torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
|
||||
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_qo_len = qo_len.max().item()
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
|
||||
self.paged_kv_indices[:num_actual_pages].copy_(
|
||||
paged_kv_indices, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indices[num_actual_pages:].fill_(-1)
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
|
||||
self.paged_kv_indptr[: 1 + num_reqs].copy_(
|
||||
paged_kv_indptr, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
|
||||
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
|
||||
|
||||
# paged_kv_last_page_len already uses the pre-initialized buffer slice
|
||||
# (set above), so no copy needed - buffer is always 1s.
|
||||
|
||||
self.qo_indptr[: 1 + num_reqs].copy_(
|
||||
query_start_loc_device, non_blocking=True
|
||||
)
|
||||
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
|
||||
qo_indptr = self.qo_indptr[: 1 + num_reqs]
|
||||
|
||||
else:
|
||||
qo_indptr = torch.arange(
|
||||
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
qo_indptr=qo_indptr,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
max_qo_len=max_qo_len,
|
||||
attn_out_dtype=self.decode_attn_out_dtype,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
assert num_heads == 16 or num_heads == 128, (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value."
|
||||
)
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
output = self.flash_attn_varlen_func( # type: ignore[call-arg]
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
assert attn_metadata.decode.max_qo_len is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(
|
||||
B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=attn_metadata.decode.attn_out_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
rocm_aiter_ops.mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer,
|
||||
o,
|
||||
self.scale,
|
||||
attn_metadata.decode.qo_indptr,
|
||||
attn_metadata.decode.max_qo_len,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len,
|
||||
q_scale=layer._q_scale,
|
||||
kv_scale=layer._k_scale,
|
||||
)
|
||||
|
||||
return o, None
|
||||
368
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
Normal file
@@ -0,0 +1,368 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fetch_id_to_ragged_kernel(
|
||||
in_tensor_ptr, # [num_seq, topk]
|
||||
cumsum_ptr, # [num_seq + 1]
|
||||
out_tensor_ptr, # [max_num_seq * topk]
|
||||
in_tensor_ptr_stride,
|
||||
TOPK: tl.constexpr,
|
||||
TOKEN_NUM: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
seq_id = tl.program_id(0)
|
||||
block_id = tl.program_id(1)
|
||||
offset = tl.arange(0, BLOCK_SIZE)
|
||||
token_start = tl.load(cumsum_ptr + seq_id)
|
||||
token_end = tl.load(cumsum_ptr + seq_id + 1)
|
||||
token_num = token_end - token_start
|
||||
row_offset = block_id * BLOCK_SIZE
|
||||
if row_offset >= token_num:
|
||||
return
|
||||
in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset
|
||||
in_tensor_mask = (row_offset + offset) < TOPK
|
||||
in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask)
|
||||
out_tensor_offset = token_start + row_offset + offset
|
||||
out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask
|
||||
tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask)
|
||||
|
||||
|
||||
def fetch_id_to_ragged_triton(
|
||||
in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
|
||||
):
|
||||
num_tokens = in_tensor.size(0)
|
||||
block_size = 64
|
||||
num_block_per_row = triton.cdiv(topk, block_size)
|
||||
grid = (
|
||||
num_tokens,
|
||||
num_block_per_row,
|
||||
)
|
||||
fetch_id_to_ragged_kernel[grid](
|
||||
in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size
|
||||
)
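# Illustrative usage (values hypothetical): with topk=4,
#   in_tensor = [[10, 11, -1, -1], [20, 21, 22, -1]]
#   cumsum    = [0, 2, 5]   # per-row valid counts, prefix-summed
# the kernel copies the first 2 entries of row 0 to out_tensor[0:2] and the
# first 3 entries of row 1 to out_tensor[2:5], i.e. it packs each row's
# leading entries into a ragged layout.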
|
||||
|
||||
|
||||
class ROCMAiterMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
|
||||
return ROCMAiterMLASparseMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
|
||||
return ROCMAiterMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
|
||||
return ROCMAiterMLASparseImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCMAiterMLASparseMetadata(AttentionMetadata):
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
|
||||
qo_indptr: torch.Tensor
|
||||
paged_kv_last_page_len: torch.Tensor
|
||||
paged_kv_indices: torch.Tensor
|
||||
paged_kv_indptr: torch.Tensor
|
||||
paged_kv_indptr_rest: torch.Tensor
|
||||
|
||||
block_size: int = 1
|
||||
topk_tokens: int = 2048
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCMAiterMLASparseMetadataBuilder(
|
||||
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
|
||||
):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.topk_tokens_tensor = torch.tensor(
|
||||
[self.topk_tokens], device=device, dtype=torch.int32
|
||||
)
|
||||
self.max_model_len_tensor = torch.tensor(
|
||||
[self.model_config.max_model_len], device=device, dtype=torch.int32
|
||||
)
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices is not None
|
||||
self.dummy_block_table = torch.empty(
|
||||
(1, 1), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.qo_indptr = torch.arange(
|
||||
0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_last_page_len = torch.ones(
|
||||
max_num_batched_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# These two need to be calculated at runtime,
|
||||
# but we still need to prepare the buffers
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
[max_num_batched_tokens * self.topk_tokens],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
[max_num_batched_tokens + 1], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> ROCMAiterMLASparseMetadata:
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
|
||||
)
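# e.g. query_start_loc [0, 2, 5] -> seg_lengths [2, 3] ->
# req_id_per_token [0, 0, 1, 1, 1]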
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
|
||||
torch.from_numpy(req_id_per_token), non_blocking=True
|
||||
)
|
||||
self.paged_kv_indices.fill_(0)
|
||||
self.paged_kv_indptr.fill_(0)
|
||||
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
qo_indptr = self.qo_indptr[: num_tokens + 1]
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
|
||||
paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
|
||||
paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
|
||||
paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]
|
||||
|
||||
metadata = ROCMAiterMLASparseMetadata(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
block_table=common_attn_metadata.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
qo_indptr=qo_indptr,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indptr_rest=paged_kv_indptr_rest,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
# Taken from
|
||||
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
|
||||
def reference_mla_sparse_prefill(
|
||||
q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
import math
|
||||
|
||||
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
|
||||
|
||||
skv = kv.shape[0]
|
||||
sq = q.shape[0]
|
||||
topk = indices.shape[-1]
|
||||
dqk = q.shape[-1]
|
||||
indices = indices[:, 0, :] # [s_q, topk]
|
||||
invalid_indices_mask = (indices < 0) | (indices >= skv)
|
||||
indices[invalid_indices_mask] = 0
|
||||
qs = q # [s_q, h_q, d_qk]
|
||||
kvs = kv[:, 0, :][indices].view(sq, topk, dqk) # [s_q, topk, d_qk]
|
||||
|
||||
attn_score = (qs @ kvs.transpose(1, 2)).float() # [s_q, h_q, topk]
|
||||
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
|
||||
attn_score *= sm_scale * math.log2(math.e)
|
||||
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
|
||||
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
|
||||
result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
|
||||
return (result, lse)
|
||||
|
||||
|
||||
class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
q: torch.Tensor, # [sq, heads, d_qk]
|
||||
kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk]
|
||||
topk_indices: torch.Tensor, # [sq, topk]
|
||||
attn_metadata: ROCMAiterMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = q.shape[0]
|
||||
output = torch.empty(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
seq_len = (topk_indices != -1).sum(dim=-1)
|
||||
torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
|
||||
attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
|
||||
fetch_id_to_ragged_triton(
|
||||
topk_indices,
|
||||
attn_metadata.paged_kv_indptr,
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
rocm_aiter_ops.mla_decode_fwd(
|
||||
q,
|
||||
kv_c_and_k_pe_cache,
|
||||
output,
|
||||
self.scale,
|
||||
attn_metadata.qo_indptr,
|
||||
1,
|
||||
attn_metadata.paged_kv_indptr,
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len,
|
||||
)
|
||||
|
||||
return output[:, : self.num_heads, :]
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: ROCMAiterMLASparseMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# NOTE(lucas): the sparse FlashMLA kernels want to use the
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
# Concatenate q if it's a tuple (ql_nope, q_pe)
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
# Get topk indices
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
topk_indices_global = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
attn_out = self._forward_bf16_kv(
|
||||
q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
|
||||
)
|
||||
|
||||
return attn_out, None
|
||||
191
vllm/v1/attention/backends/mla/sparse_utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for sparse MLA backends."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Kernel with prefill workspace support and valid count tracking
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
valid_count_ptr, # int32 [num_tokens] - output valid count per row
|
||||
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
|
||||
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
HAS_PREFILL: tl.constexpr,
|
||||
COUNT_VALID: tl.constexpr, # whether to count valid indices
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
is_prefill = False
|
||||
if HAS_PREFILL:
|
||||
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
|
||||
is_prefill = prefill_req_id >= 0
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
is_invalid_tok |= ~valid_block
|
||||
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
|
||||
out_val = base * BLOCK_SIZE + inblock_off
|
||||
|
||||
# Override with prefill output if prefill is enabled
|
||||
if HAS_PREFILL:
|
||||
workspace_start = tl.load(
|
||||
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
|
||||
)
|
||||
prefill_out = workspace_start + tok
|
||||
out_val = tl.where(is_prefill, prefill_out, out_val)
|
||||
out_val = tl.where(is_invalid_tok, -1, out_val)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
# Count valid indices in this tile and atomically add to row total
|
||||
if COUNT_VALID:
|
||||
tile_valid_count = tl.sum((~is_invalid_tok).to(tl.int32))
|
||||
tl.atomic_add(valid_count_ptr + token_id, tile_valid_count)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
return_valid_counts: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
|
||||
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
|
||||
instead of global cache slots. prefill_workspace_request_ids and
|
||||
prefill_workspace_starts must be provided.
|
||||
|
||||
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
|
||||
prefill request index (maps to prefill_workspace_starts)
|
||||
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
|
||||
starts for each prefill request
|
||||
|
||||
When return_valid_counts is True, also returns the count of valid (non -1)
|
||||
indices per row, computed during the same kernel pass (no extra overhead).
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
|
||||
)
|
||||
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None
|
||||
assert prefill_workspace_starts is not None
|
||||
assert prefill_workspace_request_ids.dtype == torch.int32
|
||||
assert prefill_workspace_starts.dtype == torch.int32
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
max_num_blocks_per_req = block_table.shape[1]
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Allocate valid count buffer if needed (must be zero-initialized for atomics)
|
||||
valid_counts: torch.Tensor | None = None
|
||||
if return_valid_counts:
|
||||
valid_counts = torch.zeros(
|
||||
num_tokens, dtype=torch.int32, device=token_indices.device
|
||||
)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Prepare prefill pointers
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None # for mypy
|
||||
assert prefill_workspace_starts is not None # for mypy
|
||||
assert prefill_workspace_request_ids.is_contiguous()
|
||||
assert prefill_workspace_starts.is_contiguous()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
valid_counts,
|
||||
prefill_workspace_request_ids,
|
||||
prefill_workspace_starts,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
HAS_PREFILL_WORKSPACE,
|
||||
return_valid_counts,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
|
||||
if return_valid_counts:
|
||||
assert valid_counts is not None
|
||||
return out, valid_counts
|
||||
return out
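# Worked example (hypothetical values): with BLOCK_SIZE=64, a token whose
# request's block_table row is [7, 3, ...] and whose per-request index is 70
# maps to block_id=1, in-block offset 6, i.e. global slot 3 * 64 + 6 = 198;
# an index of -1 (or an out-of-range block) stays -1.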
|
||||
210
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
# "auto",
|
||||
# "bfloat16",
|
||||
# ]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl"
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
# layer: AttentionLayer,
|
||||
k_c_normed: torch.Tensor | None = None,
|
||||
k_pe: torch.Tensor | None = None,
|
||||
kv_c_and_k_pe_cache_scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode
|
||||
q_nope = self._k_up_proj(q_nope)
|
||||
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q = get_dcp_group().all_gather(q, dim=1)
|
||||
o = torch.empty(B,
|
||||
q.shape[1],
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
else:
|
||||
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
|
||||
output=o,
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
return_softmax_lse=True)
|
||||
return attn_out, softmax_lse
|
||||
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
attn_metadata.decode.use_cuda_graph
|
||||
)
|
||||
else:
|
||||
# fused q concat & cache write
|
||||
ixf_ops.vllm_paged_attention_mla_fused(
|
||||
output=o,
|
||||
q_nope=q_nope,
|
||||
q_pe=q_pe.contiguous(),
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
k_c_normed=k_c_normed,
|
||||
k_pe=k_pe,
|
||||
use_cuda_graph=decode_meta.use_cuda_graph
|
||||
)
|
||||
return self._v_up_proj(o), None
|
||||
261
vllm/v1/attention/backends/registry.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention backend registry"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum, EnumMeta
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _AttentionBackendEnumMeta(EnumMeta):
|
||||
"""Metaclass for AttentionBackendEnum to provide better error messages."""
|
||||
|
||||
def __getitem__(cls, name: str):
|
||||
"""Get backend by name with helpful error messages."""
|
||||
try:
|
||||
return super().__getitem__(name)
|
||||
except KeyError:
|
||||
members = cast("dict[str, Enum]", cls.__members__).keys()
|
||||
valid_backends = ", ".join(members)
|
||||
raise ValueError(
|
||||
f"Unknown attention backend: '{name}'. "
|
||||
f"Valid options are: {valid_backends}"
|
||||
) from None
|
||||
|
||||
|
||||
class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""Enumeration of all supported attention backends.
|
||||
|
||||
The enum value is the default class path, but this can be overridden
|
||||
at runtime using register_backend().
|
||||
|
||||
To get the actual backend class (respecting overrides), use:
|
||||
backend.get_class()
|
||||
"""
|
||||
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
FLASH_ATTN_DIFFKV = (
|
||||
"vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||
ROCM_AITER_TRITON_MLA = (
|
||||
"vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
|
||||
)
|
||||
ROCM_AITER_FA = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
)
|
||||
ROCM_AITER_MLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
|
||||
)
|
||||
TORCH_SDPA = "" # this tag is only used for ViT
|
||||
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
FLASHINFER_MLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla_sparse."
|
||||
"FlashInferMLASparseBackend"
|
||||
)
|
||||
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
FLASHMLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
||||
)
|
||||
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
||||
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
|
||||
ROCM_AITER_UNIFIED_ATTN = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
|
||||
"RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
# set to None to avoid aliasing another backend whose value is an empty string
|
||||
CUSTOM = None
|
||||
|
||||
def get_path(self, include_classname: bool = True) -> str:
|
||||
"""Get the class path for this backend (respects overrides).
|
||||
|
||||
Returns:
|
||||
The fully qualified class path string
|
||||
|
||||
Raises:
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
path = _ATTN_OVERRIDES.get(self, self.value)
|
||||
if not path:
|
||||
raise ValueError(
|
||||
f"Backend {self.name} must be registered before use. "
|
||||
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
|
||||
)
|
||||
if not include_classname:
|
||||
path = path.rsplit(".", 1)[0]
|
||||
return path
|
||||
|
||||
def get_class(self) -> "type[AttentionBackend]":
|
||||
"""Get the backend class (respects overrides).
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
|
||||
Raises:
|
||||
ImportError: If the backend class cannot be imported
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
return resolve_obj_by_qualname(self.get_path())
|
||||
|
||||
def is_overridden(self) -> bool:
|
||||
"""Check if this backend has been overridden.
|
||||
|
||||
Returns:
|
||||
True if the backend has a registered override
|
||||
"""
|
||||
return self in _ATTN_OVERRIDES
|
||||
|
||||
def clear_override(self) -> None:
|
||||
"""Clear any override for this backend, reverting to the default."""
|
||||
_ATTN_OVERRIDES.pop(self, None)
|
||||
|
||||
|
||||
class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""Enumeration of all supported mamba attention backends.
|
||||
|
||||
The enum value is the default class path, but this can be overridden
|
||||
at runtime using register_backend().
|
||||
|
||||
To get the actual backend class (respecting overrides), use:
|
||||
backend.get_class()
|
||||
"""
|
||||
|
||||
MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
|
||||
MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
|
||||
SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
|
||||
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
|
||||
GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
# Set to None to avoid aliasing another backend whose value is an empty string
|
||||
CUSTOM = None
|
||||
|
||||
def get_path(self, include_classname: bool = True) -> str:
|
||||
"""Get the class path for this backend (respects overrides).
|
||||
|
||||
Returns:
|
||||
The fully qualified class path string
|
||||
|
||||
Raises:
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
|
||||
if not path:
|
||||
raise ValueError(
|
||||
f"Backend {self.name} must be registered before use. "
|
||||
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
|
||||
)
|
||||
if not include_classname:
|
||||
path = path.rsplit(".", 1)[0]
|
||||
return path
|
||||
|
||||
def get_class(self) -> "type[AttentionBackend]":
|
||||
"""Get the backend class (respects overrides).
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
|
||||
Raises:
|
||||
ImportError: If the backend class cannot be imported
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
return resolve_obj_by_qualname(self.get_path())
|
||||
|
||||
def is_overridden(self) -> bool:
|
||||
"""Check if this backend has been overridden.
|
||||
|
||||
Returns:
|
||||
True if the backend has a registered override
|
||||
"""
|
||||
return self in _MAMBA_ATTN_OVERRIDES
|
||||
|
||||
def clear_override(self) -> None:
|
||||
"""Clear any override for this backend, reverting to the default."""
|
||||
_MAMBA_ATTN_OVERRIDES.pop(self, None)
|
||||
|
||||
|
||||
MAMBA_TYPE_TO_BACKEND_MAP = {
|
||||
"mamba1": MambaAttentionBackendEnum.MAMBA1.name,
|
||||
"mamba2": MambaAttentionBackendEnum.MAMBA2.name,
|
||||
"short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
|
||||
"linear_attention": MambaAttentionBackendEnum.LINEAR.name,
|
||||
"gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
|
||||
"custom": MambaAttentionBackendEnum.CUSTOM.name,
|
||||
}
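# Illustrative lookup only (all names as defined above): a mamba layer type
# maps to a backend name, which the enum resolves to the backend class.
#   backend_name = MAMBA_TYPE_TO_BACKEND_MAP["mamba2"]               # "MAMBA2"
#   backend_cls = MambaAttentionBackendEnum[backend_name].get_class()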
|
||||
|
||||
|
||||
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
|
||||
_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
|
||||
|
||||
|
||||
def register_backend(
|
||||
backend: AttentionBackendEnum | MambaAttentionBackendEnum,
|
||||
class_path: str | None = None,
|
||||
is_mamba: bool = False,
|
||||
) -> Callable[[type], type]:
|
||||
"""Register or override a backend implementation.
|
||||
|
||||
Args:
|
||||
backend: The AttentionBackendEnum or MambaAttentionBackendEnum member to register
|
||||
class_path: Optional class path. If not provided and used as
|
||||
decorator, will be auto-generated from the class.
is_mamba: Whether `backend` is a MambaAttentionBackendEnum member; if True,
the override is stored in the mamba override registry.
|
||||
|
||||
Returns:
|
||||
Decorator function if class_path is None, otherwise a no-op
|
||||
|
||||
Examples:
|
||||
# Override an existing attention backend
|
||||
@register_backend(AttentionBackendEnum.FLASH_ATTN)
|
||||
class MyCustomFlashAttn:
|
||||
...
|
||||
|
||||
# Override an existing mamba attention backend
|
||||
@register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
|
||||
class MyCustomMambaAttn:
|
||||
...
|
||||
|
||||
# Register a custom third-party attention backend
|
||||
@register_backend(AttentionBackendEnum.CUSTOM)
|
||||
class MyCustomBackend:
|
||||
...
|
||||
|
||||
# Direct registration
|
||||
register_backend(
|
||||
AttentionBackendEnum.CUSTOM,
|
||||
"my.module.MyCustomBackend"
|
||||
)
|
||||
"""
|
||||
|
||||
def decorator(cls: type) -> type:
|
||||
if is_mamba:
|
||||
_MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
|
||||
else:
|
||||
_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
|
||||
return cls
|
||||
|
||||
if class_path is not None:
|
||||
if is_mamba:
|
||||
_MAMBA_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
|
||||
else:
|
||||
_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
|
||||
return lambda x: x
|
||||
|
||||
return decorator
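A minimal usage sketch of the registry above; the module path "my_pkg.backends.MyBackend" is hypothetical and only illustrates the flow, it is not part of vLLM.

# Point the CUSTOM slot at a third-party backend class, then resolve it.
register_backend(AttentionBackendEnum.CUSTOM, "my_pkg.backends.MyBackend")
assert AttentionBackendEnum.CUSTOM.is_overridden()
backend_cls = AttentionBackendEnum.CUSTOM.get_class()  # imports the class at call time

# get_path() returns the registered path; include_classname=False drops the class name.
assert AttentionBackendEnum.CUSTOM.get_path(include_classname=False) == "my_pkg.backends"

# Remove the override again; CUSTOM has no default, so it must be re-registered before use.
AttentionBackendEnum.CUSTOM.clear_override()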
|
||||
1336
vllm/v1/attention/backends/rocm_aiter_fa.py
Normal file
File diff suppressed because it is too large
249
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionLayer, AttentionType
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_attn import (
|
||||
RocmAttentionBackend,
|
||||
RocmAttentionImpl,
|
||||
RocmAttentionMetadataBuilder,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_UNIFIED_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
|
||||
return RocmAiterUnifiedAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
sinks,
|
||||
)
|
||||
logger.info_once(
|
||||
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
|
||||
)
|
||||
from aiter.ops.triton.unified_attention import unified_attention
|
||||
|
||||
self.unified_attention = unified_attention
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for RocmAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
descale_shape = (
|
||||
cu_seqlens_q.shape[0] - 1,
|
||||
key.shape[1] if key is not None else self.num_kv_heads,
|
||||
)
|
||||
|
||||
self.unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
|
||||
return rocm_aiter_ops.is_enabled()
|
||||
|
||||
def do_rope_and_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
kv_cache: torch.Tensor,
|
||||
layer_slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
flash_layout = True
|
||||
|
||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||
if is_fp8_kv_cache:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
|
||||
rocm_aiter_ops.triton_rope_and_cache(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
key_cache,
|
||||
value_cache,
|
||||
layer_slot_mapping,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
flash_layout,
|
||||
is_fp8_kv_cache,
|
||||
)
|
||||
461
vllm/v1/attention/backends/rocm_attn.py
Normal file
@@ -0,0 +1,461 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode,
|
||||
)
|
||||
from vllm.v1.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocmAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
|
||||
|
||||
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> RocmAttentionMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# When doing full graph capture, setting seq_lens to
|
||||
# max_model_len will cause graph capture to be extremely
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
|
||||
# Here we set the query start locs to 0. This is to
|
||||
# work around an invalid memory access in the prefix_prefill kernel
|
||||
# that we run into during graph capture (#25985)
|
||||
common_attn_metadata.query_start_loc.zero_()
|
||||
common_attn_metadata.query_start_loc_cpu.zero_()
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> RocmAttentionMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor(
|
||||
[0, num_actual_tokens], dtype=torch.int32, device=self.device
|
||||
)
|
||||
prefix_kv_lens = torch.tensor(
|
||||
[common_prefix_len], dtype=torch.int32, device=self.device
|
||||
)
|
||||
suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
|
||||
suffix_kv_lens = suffix_kv_lens.to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
|
||||
attn_metadata = RocmAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class RocmAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
# ROCM paged attention kernel only supports block sizes 16 and 32
|
||||
# due to shared memory (LDS) constraints on AMD GPUs.
|
||||
# See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
|
||||
|
||||
# However, the [16, 32] limitation is reasonable for the native C++ kernel,
|
||||
# but vLLM should allow non-standard sizes via the Triton path,
|
||||
# as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
|
||||
# where the Triton kernel under rocm_attn did not support inference
|
||||
# for a non-standard qwen3-next model with a block_size of 544.
|
||||
# We have fixed the Triton kernel so that the standard model uses the original
|
||||
# bit-addressing logic, while the non-standard model
|
||||
# uses our optimized kernel logic.
|
||||
return [16, 32, 544]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
if not cls.supports_head_size(head_size):
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
|
||||
"Set --attention-backend=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["RocmAttentionImpl"]:
|
||||
return RocmAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
|
||||
class RocmAttentionImpl(AttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
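# (left, right) window sizes in the flash-attn convention; causal decoding
# only attends to earlier tokens, so the right span is 0.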
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap to 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
RocmAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention is not implemented for RocmAttentionImpl"
|
||||
)
|
||||
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for RocmAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens] if key is not None else None,
|
||||
value=value[:num_actual_tokens] if value is not None else None,
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
output_scale=output_scale,
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Get the actual block_size from value_cache
|
||||
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# Determine if it is a power of 2
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
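# e.g., 16 and 32 are powers of two, while 544 (= 2**5 * 17) is not and
# takes the Triton path below.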
|
||||
|
||||
if is_pow2:
|
||||
# Power-of-2 block sizes (16, 32, 64, etc.): use the vLLM native HIP C++ logic
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Non-power-of-2 block sizes (e.g., 544 in Qwen3-Next):
|
||||
# fall back to the modified Triton logic
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
|
||||
return rocm_aiter_ops.is_enabled()
|
||||
|
||||
def do_rope_and_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
kv_cache: torch.Tensor,
|
||||
layer_slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache,
|
||||
layer.num_kv_heads, # type: ignore[attr-defined]
|
||||
layer.head_size, # type: ignore[attr-defined]
|
||||
)
|
||||
flash_layout = False
|
||||
|
||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||
if is_fp8_kv_cache:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
|
||||
rocm_aiter_ops.triton_rope_and_cache(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
key_cache,
|
||||
value_cache,
|
||||
layer_slot_mapping,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
flash_layout,
|
||||
is_fp8_kv_cache,
|
||||
)
|
||||
30
vllm/v1/attention/backends/short_conv_attn.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
)
|
||||
|
||||
|
||||
class ShortConvAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "SHORT_CONV_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
|
||||
return ShortConvAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
|
||||
pass
|
||||
|
||||
|
||||
class ShortConvAttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
|
||||
):
|
||||
metadata_cls = ShortConvAttentionMetadata
|
||||
430
vllm/v1/attention/backends/tree_attn.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with TreeAttention."""
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TreeAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TREE_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TreeAttentionImpl"]:
|
||||
return TreeAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
|
||||
return TreeAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TreeAttentionMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
|
||||
tree_attn_bias: torch.Tensor | None = None
|
||||
|
||||
# Cached Prefill/decode metadata.
|
||||
_cached_prefill_metadata: "TreeAttentionMetadata | None" = None
|
||||
_cached_decode_metadata: "TreeAttentionMetadata | None" = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> "TreeAttentionMetadata | None":
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[self.num_decodes :]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[self.num_decodes :]
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = TreeAttentionMetadata(
|
||||
num_actual_tokens=self.num_prefill_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc - q_start_loc[0],
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[self.num_decodes :],
|
||||
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> "TreeAttentionMetadata | None":
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[: self.num_decodes + 1]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[: self.num_decodes]
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = TreeAttentionMetadata(
|
||||
num_actual_tokens=self.num_decode_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc,
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[: self.num_decodes],
|
||||
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
|
||||
tree_attn_bias=self.tree_attn_bias,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
spec_config = vllm_config.speculative_config
|
||||
spec_token_tree: str | None = None
|
||||
if spec := spec_config:
|
||||
spec_token_tree = spec.speculative_token_tree
|
||||
tree_choices: list[tuple[int, ...]] = (
|
||||
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
|
||||
)
|
||||
# Construct the tree attention bias.
|
||||
depth_counts = _get_depth_counts(tree_choices)
|
||||
self.tree_attn_bias = _prepare_tree_attn_bias(
|
||||
tree_choices,
|
||||
depth_counts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> TreeAttentionMetadata:
|
||||
decode_threshold = self.tree_attn_bias.shape[0]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=decode_threshold
|
||||
)
|
||||
)
|
||||
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
q_start_loc = common_attn_metadata.query_start_loc
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
kv_seqlens = common_attn_metadata.seq_lens
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
return TreeAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=q_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
tree_attn_bias=self.tree_attn_bias,
|
||||
)
|
||||
|
||||
def build_for_drafting(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
draft_index: int,
|
||||
) -> TreeAttentionMetadata:
|
||||
# Cache the original tree attention bias.
|
||||
orig_tree_attn_bias = self.tree_attn_bias
|
||||
|
||||
if draft_index == 0:
|
||||
# Use prefill for drafting at the root level.
|
||||
self.tree_attn_bias = torch.empty(0)
|
||||
else:
|
||||
# Slice the tree attention bias for drafting. Exclude
|
||||
# the root level.
|
||||
start, end = 1, 1 + common_attn_metadata.max_query_len
|
||||
self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()
|
||||
|
||||
# Build attention metadata.
|
||||
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
|
||||
|
||||
# Reset the tree attention bias to the original value.
|
||||
self.tree_attn_bias = orig_tree_attn_bias
|
||||
return attn_metadata
|
||||
|
||||
|
||||
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
|
||||
# Count the number of choices at each depth of the tree.
|
||||
depth_counts = []
|
||||
prev_depth = 0
|
||||
for path in sorted_tree_choices:
|
||||
depth = len(path)
|
||||
if depth != prev_depth:
|
||||
depth_counts.append(0)
|
||||
depth_counts[depth - 1] += 1
|
||||
prev_depth = depth
|
||||
return depth_counts
|
||||
|
||||
|
||||
def _prepare_tree_attn_bias(
|
||||
sorted_tree_choices: list[tuple[int, ...]],
|
||||
depth_counts: list[int],
|
||||
dtype: torch.dtype | None,
|
||||
device: torch.device | None,
|
||||
) -> torch.Tensor:
|
||||
# +1 comes from the additional root node.
|
||||
tree_len = len(sorted_tree_choices) + 1
|
||||
tree_attn_mask = torch.full(
|
||||
(tree_len, tree_len), -torch.inf, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# Set diagonal to all zeros. Each token should
|
||||
# attend to itself.
|
||||
mask_val = 0
|
||||
for i in range(tree_len):
|
||||
tree_attn_mask[i, i] = mask_val
|
||||
|
||||
# Set root to all zeros. All tokens attend to it.
|
||||
tree_attn_mask[:, 0] = mask_val
|
||||
|
||||
# Set all ancestors to zeros.
|
||||
start = 0
|
||||
for i in range(len(depth_counts)):
|
||||
for j in range(depth_counts[i]):
|
||||
cur_tree_choice = sorted_tree_choices[start + j]
|
||||
# Retrieve ancestor position.
|
||||
if len(cur_tree_choice) == 1:
|
||||
continue
|
||||
ancestor_idx = []
|
||||
for c in range(len(cur_tree_choice) - 1):
|
||||
ancestor_idx.append(
|
||||
sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
|
||||
)
|
||||
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
|
||||
start += depth_counts[i]
|
||||
return tree_attn_mask
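# Worked example for the two helpers above (illustrative values only):
#   tree_choices = [(0,), (1,), (0, 0)]
#   depth_counts = _get_depth_counts(tree_choices)        # [2, 1]
#   bias = _prepare_tree_attn_bias(
#       tree_choices, depth_counts, dtype=torch.float32, device=None
#   )
# Rows/columns are [root, (0,), (1,), (0, 0)]; 0 means "may attend", -inf masks:
#   [[  0., -inf, -inf, -inf],
#    [  0.,   0., -inf, -inf],
#    [  0., -inf,   0., -inf],
#    [  0.,   0., -inf,   0.]]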
|
||||
|
||||
|
||||
class TreeAttentionImpl(AttentionImpl):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if logits_soft_cap is None:
|
||||
# Setting logits_soft_cap to 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TreeAttentionImpl."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TreeAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with TreeAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for TreeAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
# Cache the input KVs.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
unified_attention(
|
||||
q=query[num_decode_tokens:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:num_actual_tokens],
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
seqused_k=prefill_meta.seq_lens,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=prefill_meta.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
unified_attention(
|
||||
q=query[:num_decode_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_decode_tokens],
|
||||
cu_seqlens_q=decode_meta.query_start_loc,
|
||||
max_seqlen_q=decode_meta.max_query_len,
|
||||
seqused_k=decode_meta.seq_lens,
|
||||
max_seqlen_k=decode_meta.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
qq_bias=decode_meta.tree_attn_bias,
|
||||
window_size=self.sliding_window,
|
||||
block_table=decode_meta.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return output
|
||||
638
vllm/v1/attention/backends/triton_attn.py
Normal file
@@ -0,0 +1,638 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""High-Performance Triton-only Attention layer."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# constants
|
||||
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
|
||||
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
|
||||
|
||||
|
||||
@dataclass
|
||||
class TritonAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
seq_threshold_3D: int
|
||||
num_par_softmax_segments: int
|
||||
softmax_segm_output: torch.Tensor
|
||||
softmax_segm_max: torch.Tensor
|
||||
softmax_segm_expsum: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||
|
||||
@property
|
||||
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
|
||||
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
|
||||
|
||||
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
|
||||
Empty ranges have start==end==0, which the kernel skips via its is_valid check.
|
||||
"""
|
||||
# TODO(Isotr0py): Move to model runner's attention metadata
|
||||
# preparation to avoid duplicate computation.
|
||||
if self.mm_prefix_range is None:
|
||||
return None
|
||||
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
device = self.seq_lens.device
|
||||
|
||||
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
|
||||
range_lists = [
|
||||
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
|
||||
]
|
||||
|
||||
# Return None if all ranges are trivial (only (0,0) placeholders)
|
||||
if all(r == [(0, 0)] for r in range_lists):
|
||||
return None
|
||||
|
||||
# Create 2D tensors with shape (num_ranges, 2) for each sequence
|
||||
range_tensors = [
|
||||
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
|
||||
for r in range_lists
|
||||
]
|
||||
|
||||
return torch.nested.nested_tensor(
|
||||
range_tensors, layout=torch.jagged
|
||||
).to_padded_tensor(0)
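# Small illustration of the padding above (made-up values): with 3 sequences and
#   mm_prefix_range = {0: [(2, 6)], 2: [(1, 3), (5, 9)]}
# sequence 1 falls back to the (0, 0) placeholder and the padded result has
# shape (3, 2, 2):
#   [[[2, 6], [0, 0]],
#    [[0, 0], [0, 0]],
#    [[1, 3], [5, 9]]]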
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
# Check if CUDA Graphs are enabled for decode
|
||||
self.decode_cudagraph_enabled = (
|
||||
self.vllm_config.compilation_config.cudagraph_mode
|
||||
in (
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
CUDAGraphMode.FULL_DECODE_ONLY,
|
||||
CUDAGraphMode.FULL,
|
||||
)
|
||||
)
|
||||
|
||||
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
|
||||
# A lower bound for num_q_blocks is the number of sequences.
|
||||
# To ensure the minimum launch grid size is achieved, the number of sequences
|
||||
# must be at least equal to the threshold below.
|
||||
# If this threshold is not reached (i.e., the batch size is not large enough),
|
||||
# the 3D kernel will be selected instead.
|
||||
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
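# For example, with MIN_LAUNCH_GRID_SIZE_2D = 128 and 8 KV heads, batches with
# fewer than 16 sequences take the 3D kernel path.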
|
||||
|
||||
# Modify the threshold if needed.
|
||||
if self.decode_cudagraph_enabled:
|
||||
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
|
||||
|
||||
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
|
||||
# as threshold. This ensures that each captured graph covers the
|
||||
# correct execution path.
|
||||
self.seq_threshold_3D = min(
|
||||
capture_sizes,
|
||||
key=lambda x: abs(x - self.seq_threshold_3D),
|
||||
)
|
||||
|
||||
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
|
||||
headdim_padded = next_power_of_2(self.headdim)
|
||||
self.softmax_segm_output = torch.empty(
|
||||
(
|
||||
self.seq_threshold_3D,
|
||||
self.num_heads_q,
|
||||
self.num_par_softmax_segments,
|
||||
headdim_padded,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.softmax_segm_max = torch.empty(
|
||||
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.softmax_segm_expsum = torch.empty(
|
||||
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# When doing full graph capture, setting seq_lens to
|
||||
# max_model_len will cause graph capture to be extremely
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> TritonAttentionMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor(
|
||||
[0, num_actual_tokens], dtype=torch.int32, device=self.device
|
||||
)
|
||||
prefix_kv_lens = torch.tensor(
|
||||
[common_prefix_len], dtype=torch.int32, device=self.device
|
||||
)
|
||||
suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
|
||||
suffix_kv_lens = suffix_kv_lens.to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
|
||||
attn_metadata = TritonAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
seq_threshold_3D=self.seq_threshold_3D,
|
||||
num_par_softmax_segments=self.num_par_softmax_segments,
|
||||
softmax_segm_output=self.softmax_segm_output,
|
||||
softmax_segm_max=self.softmax_segm_max,
|
||||
softmax_segm_expsum=self.softmax_segm_expsum,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonAttentionImpl"]:
|
||||
return TritonAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
# `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
if include_num_layers_dimension:
|
||||
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
|
||||
return (1, 0, 2, 3, 4, 5)
|
||||
|
||||
# (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
return (0, 1, 2, 3, 4)
|
||||
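# Illustrative example only (not part of this change), assuming num_blocks=8,
# num_layers=3, block_size=16, num_kv_heads=4, head_size=128:
#   logical shape  : (8, 3, 2, 16, 4, 128)  # include_num_layers_dimension=True
#   stride order   : (1, 0, 2, 3, 4, 5)
#   physical shape : (3, 8, 2, 16, 4, 128)  # num_layers becomes the outermost dim
# i.e. the buffer would be allocated with dims permuted by the stride order and
# viewed back through the inverse permutation to recover the logical dim order.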
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
|
||||
return TritonAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size >= 32
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""TritonAttention supports all attention types."""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_alibi_sqrt(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
use_alibi_sqrt: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
elif attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
|
||||
self.sliding_window = (sliding_window - 1, sliding_window - 1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
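# Illustrative values only (not part of this change): with sliding_window=128, a
# decoder layer ends up with window_size=(127, 0), i.e. each query attends to
# itself plus the 127 preceding tokens, while encoder/encoder-only layers get
# the symmetric window (127, 127).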
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.attn_type = attn_type
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}."
|
||||
)
|
||||
self.use_alibi_sqrt = use_alibi_sqrt
|
||||
self.supports_quant_query_input = current_platform.is_cuda()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TritonAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Paged Attention impl. in Triton.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for TritonAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
# Handle encoder attention differently - no KV cache needed
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return self._forward_encoder_attention(
|
||||
query[:num_actual_tokens],
|
||||
key[:num_actual_tokens],
|
||||
value[:num_actual_tokens],
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
layer,
|
||||
)
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
seq_threshold_3D = attn_metadata.seq_threshold_3D
|
||||
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
|
||||
softmax_segm_output = attn_metadata.softmax_segm_output
|
||||
softmax_segm_max = attn_metadata.softmax_segm_max
|
||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
use_alibi_sqrt=self.use_alibi_sqrt,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
mm_prefix_range=mm_prefix_range_tensor,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_encoder_attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
attn_metadata: TritonAttentionMetadata,
|
||||
layer: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass for encoder attention without KV cache.
|
||||
|
||||
Args:
|
||||
query: shape = [num_encoder_tokens, num_heads, head_size]
|
||||
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
|
||||
output: shape = [num_encoder_tokens, num_heads, head_size]
|
||||
attn_metadata: Encoder attention metadata
|
||||
layer: The attention layer
|
||||
"""
|
||||
# For encoder attention, process FP8 quantization if needed
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError(
|
||||
"quantization is not supported for encoder attention"
|
||||
)
|
||||
|
||||
# Use encoder-specific metadata for sequence information
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
seq_lens = attn_metadata.seq_lens
|
||||
max_query_len = attn_metadata.max_query_len
|
||||
|
||||
# Call flash attention directly on Q, K, V tensors
|
||||
context_attention_fwd(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
o=output,
|
||||
b_start_loc=query_start_loc,
|
||||
b_seq_len=seq_lens,
|
||||
max_input_len=max_query_len,
|
||||
is_causal=False, # Encoder attention is bidirectional
|
||||
softmax_scale=self.scale,
|
||||
sliding_window_q=self.sliding_window[0],
|
||||
sliding_window_k=self.sliding_window[1],
|
||||
)
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
|
||||
return rocm_aiter_ops.is_enabled()
|
||||
|
||||
def do_rope_and_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
kv_cache: torch.Tensor,
|
||||
layer_slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
flash_layout = True
|
||||
|
||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||
if is_fp8_kv_cache:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
|
||||
rocm_aiter_ops.triton_rope_and_cache(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
key_cache,
|
||||
value_cache,
|
||||
layer_slot_mapping,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
flash_layout,
|
||||
is_fp8_kv_cache,
|
||||
)
|
||||
866
vllm/v1/attention/backends/utils.py
Normal file
@@ -0,0 +1,866 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field, fields, make_dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
get_args,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def is_valid_kv_cache_layout(value: str) -> bool:
|
||||
return value in get_args(KVCacheLayoutType)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
# Format specified by the code.
|
||||
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||
|
||||
cache_layout: Literal["NHD", "HND"] | None = None
|
||||
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||
logger.info_once(
|
||||
"`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
|
||||
"Setting KV cache layout to %s.",
|
||||
cache_layout,
|
||||
)
|
||||
return cache_layout
|
||||
|
||||
# Format specified by the user.
|
||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||
# When neither the user nor the override specified a layout, get default
|
||||
if cache_layout is None:
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
else:
|
||||
assert is_valid_kv_cache_layout(cache_layout)
|
||||
logger.info_once(
|
||||
"`VLLM_KV_CACHE_LAYOUT` environment variable "
|
||||
"detected. Setting KV cache layout to %s.",
|
||||
cache_layout,
|
||||
)
|
||||
return cache_layout
|
||||
|
||||
|
||||
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
|
||||
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||
_KV_CACHE_LAYOUT_OVERRIDE = cache_layout
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
Currently, the FlashInfer backend only supports models in which all layers share
|
||||
the same values for the following hyperparameters. It should not be used for the
|
||||
trtllm-gen backend, since that backend supports different values for these
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
window_left: int
|
||||
logits_soft_cap: float | None
|
||||
sm_scale: float
|
||||
has_sinks: bool = False
|
||||
# whether all layers share the same params
|
||||
has_same_window_lefts: bool | None = field(default=None, compare=False)
|
||||
has_same_all_params: bool | None = field(default=None, compare=False)
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
|
||||
) -> dict[str, PerLayerParameters]:
|
||||
"""
|
||||
Scan layers in `layer_names` and determine some hyperparameters
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = get_layers_from_vllm_config(
|
||||
vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
layer_names,
|
||||
)
|
||||
per_layer_params: dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
impl = layer.impl
|
||||
assert isinstance(impl, cls_)
|
||||
|
||||
# Infer hyperparameters from the attention layer
|
||||
window_size = getattr(impl, "sliding_window", None)
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
|
||||
sm_scale = impl.scale
|
||||
has_sinks = getattr(impl, "sinks", None) is not None
|
||||
|
||||
per_layer_params[key] = PerLayerParameters(
|
||||
window_left, logits_soft_cap, sm_scale, has_sinks
|
||||
)
|
||||
|
||||
return per_layer_params
|
||||
|
||||
|
||||
def infer_global_hyperparameters(
|
||||
per_layer_params: dict[str, PerLayerParameters],
|
||||
) -> PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backends other than trtllm-gen
|
||||
only support models in which all layers share
|
||||
the same values for the following hyperparameters:
|
||||
- `window_left`
|
||||
- `logits_soft_cap`
|
||||
- `sm_scale`
|
||||
|
||||
So this function asserts that all layers share the same values for these
|
||||
hyperparameters and returns the global values.
|
||||
"""
|
||||
|
||||
assert len(per_layer_params) > 0, "No attention layers found in the model."
|
||||
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
|
||||
global_params.has_same_window_lefts = all(
|
||||
params.window_left == global_params.window_left for params in param_sets
|
||||
)
|
||||
global_params.has_same_all_params = all(
|
||||
params == global_params for params in param_sets
|
||||
)
|
||||
|
||||
return global_params
|
||||
|
||||
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
# as an independent local ("virtual") batch item.
|
||||
#
|
||||
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# kv_seqlens = [6, 17, 9]
|
||||
# Then normally for regular attention we would compute with an attention mask
|
||||
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1 1 1 1 1
|
||||
# 3 | 1 1 1 1 1 1
|
||||
#
|
||||
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||
# attention mask like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# We can simulate this mask using standard flash-attention by breaking the
|
||||
# sequences into local ("virtual") batches, where each local batch item is a
|
||||
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||
#
|
||||
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||
# k_toks > 0 1 2 3
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||
# k_toks > 4 5
|
||||
# q_toks v _____________
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# e.g. if we have:
|
||||
# attn_chunk_size = 4
|
||||
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||
# Then this function would return:
|
||||
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
block_size: int = 0,
|
||||
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
|
||||
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
|
||||
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
# Handle the case where we start in the middle of a local attention block;
|
||||
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||
# the number of tokens that are not in the first local attention block and
|
||||
# then we can simply use a cdiv for the rest.
|
||||
# For example if we have:
|
||||
# attn_chunk_size = 4
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# k_seqlens = [6, 17, 9]
|
||||
# Then we would get:
|
||||
# new_tokens_in_first_block = [2, 1, 4]
|
||||
# local_blocks = [2, 4, 2]
|
||||
q_tokens_in_first_block = np.minimum(
|
||||
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
|
||||
).astype(np.int32)
|
||||
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
|
||||
|
||||
# Once we know the number of local blocks we can compute the request spans
|
||||
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||
# have to make.
|
||||
# For the above example we would get:
|
||||
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||
#
|
||||
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||
# (TODO: make a utility to share this code with _prepare_inputs)
|
||||
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||
cu_num_blocks = np.cumsum(local_blocks)
|
||||
virtual_batches = cu_num_blocks[-1]
|
||||
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||
# first and last blocks could be partial
|
||||
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||
# set the first block since this may be a partial block
|
||||
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||
# set the remaining blocks
|
||||
seqlens_q_local[arange > 0] = np.minimum(
|
||||
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
|
||||
)[arange > 0]
|
||||
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
|
||||
np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
|
||||
cu_seqlens_q_local[0] = 0
|
||||
|
||||
# compute the seqlens_k_local,
|
||||
# basically a full local attention block for all but the last block in each
|
||||
# batch
|
||||
# For our example this will be:
|
||||
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
num_computed_tokens_local = seqlens_k_local - seqlens_q_local
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
|
||||
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
|
||||
)
|
||||
# For the example the local attention blocks start at:
|
||||
# _b0_ _____b1_____ _b2_
|
||||
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||
block_starts = k_seqstarts_absolute // block_size
|
||||
assert attn_chunk_size % block_size == 0, (
|
||||
f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
|
||||
)
|
||||
pages_per_local_batch = attn_chunk_size // block_size
|
||||
|
||||
# Create a block_table for the local attention blocks
|
||||
# For our example, if we have a block-table like (assuming block_size=2):
|
||||
# block_table = [
|
||||
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||
# ]
|
||||
# Then for the local batches we would want a block-table like
|
||||
# block_table_local = [
|
||||
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||
# ]
|
||||
block_indices = block_starts[:, None] + np.arange(
|
||||
pages_per_local_batch, dtype=np.int32
|
||||
)
|
||||
block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
|
||||
batch_indices = np.repeat(
|
||||
np.arange(actual_batch_size, dtype=np.int32),
|
||||
local_blocks * pages_per_local_batch,
|
||||
)
|
||||
|
||||
# NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
|
||||
# regression when using numpy arrays (batch and block indices) to index into
|
||||
# torch tensor (block_table). As a workaround, convert numpy arrays to torch
|
||||
# tensor first, which recovers perf.
|
||||
batch_indices_torch = torch.from_numpy(batch_indices)
|
||||
block_indices_torch = torch.from_numpy(block_indices)
|
||||
|
||||
# Save as a lambda so we can return this for update_block_table
|
||||
make_block_table = lambda block_table: block_table[
|
||||
batch_indices_torch, block_indices_torch
|
||||
].view(virtual_batches, -1)
|
||||
block_table_local = make_block_table(block_table)
|
||||
|
||||
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
|
||||
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
|
||||
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
|
||||
num_reqs=len(seq_lens_cpu),
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
max_query_len=seqlens_q_local.max(),
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_local,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
causal=True,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
|
||||
), make_block_table
|
||||
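# A minimal NumPy sketch (illustrative only, not exercised anywhere in this
# change) reproducing the local-attention split documented above for
# attn_chunk_size=4, q_seqlens=[4, 10, 5] and kv_seqlens=[6, 17, 9]:
#   import numpy as np
#   chunk = 4
#   q, k = np.array([4, 10, 5]), np.array([6, 17, 9])
#   q_first = np.minimum(chunk - ((k - q) % chunk), q)        # [2, 1, 4]
#   last_k = chunk + (k % -chunk)                              # [2, 1, 1]
#   blocks = 1 + -(-(q - q_first) // chunk)                    # [2, 4, 2]
#   cu_blocks = np.cumsum(blocks)
#   arange = np.arange(cu_blocks[-1]) - np.repeat(cu_blocks - blocks, blocks)
#   q_local = np.repeat(q - q_first, blocks)
#   q_local[arange == 0] = q_first
#   q_local[arange > 0] = np.minimum(q_local - chunk * (arange - 1), chunk)[arange > 0]
#   k_local = np.full(cu_blocks[-1], chunk)
#   k_local[cu_blocks - 1] = last_k
#   # q_local -> [2, 2, 1, 4, 4, 1, 4, 1];  k_local -> [4, 2, 4, 4, 4, 1, 4, 1]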
|
||||
|
||||
def make_kv_sharing_fast_prefill_common_attn_metadata(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> CommonAttentionMetadata:
|
||||
if common_attn_metadata.max_query_len == 1:
|
||||
# All requests are decode (assume 1 token for now)
|
||||
# Skip computing fast prefill path
|
||||
return common_attn_metadata
|
||||
|
||||
assert common_attn_metadata.logits_indices_padded is not None
|
||||
assert common_attn_metadata.num_logits_indices is not None
|
||||
|
||||
logits_indices_padded = common_attn_metadata.logits_indices_padded
|
||||
num_logits_indices = common_attn_metadata.num_logits_indices
|
||||
# Get rid of CUDAGraph padding, if any
|
||||
logits_indices = logits_indices_padded[:num_logits_indices]
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
# Example inputs
|
||||
# num_reqs: 3
|
||||
# generation_indices: [14, 18, 19, 27]
|
||||
# query_start_loc: [0, 15, 20, 28]
|
||||
# seq_lens: [41, 31, 40]
|
||||
|
||||
# Find how many decode indices belong to each request
|
||||
# request_ids: [0, 1, 1, 2]
|
||||
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
|
||||
|
||||
# Figure out how many tokens are in each request
|
||||
# num_decode_tokens: [1, 2, 1]
|
||||
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
|
||||
|
||||
# Calculate new query_start_loc with tokens in generation_indices
|
||||
# decode_query_start_loc: [0, 1, 3, 4]
|
||||
decode_query_start_loc = torch.empty(
|
||||
num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype
|
||||
)
|
||||
|
||||
decode_query_start_loc[0] = 0
|
||||
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
|
||||
decode_max_query_len = int(num_decode_tokens.max().item())
|
||||
total_num_decode_tokens = int(num_decode_tokens.sum().item())
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=decode_query_start_loc,
|
||||
query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_decode_tokens,
|
||||
max_query_len=decode_max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
causal=True,
|
||||
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
|
||||
)
|
||||
return common_attn_metadata
|
||||
|
||||
|
||||
def split_decodes_prefills_and_extends(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
) -> tuple[int, int, int, int, int, int]:
|
||||
"""
|
||||
Assuming a reordered batch, finds the boundaries between decode, extend,
|
||||
and prefill requests.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: CommonAttentionMetadata object containing the
|
||||
batch metadata.
|
||||
decode_threshold: The maximum query length to be considered a decode.
|
||||
|
||||
Returns:
|
||||
num_decodes: The number of decode requests.
|
||||
num_extends: The number of extend requests.
|
||||
num_prefills: The number of prefill requests.
|
||||
num_decode_tokens: The number of tokens in the decode requests.
|
||||
num_extend_tokens: The number of tokens in the extend requests.
|
||||
num_prefill_tokens: The number of tokens in the prefill requests.
|
||||
"""
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu
|
||||
|
||||
if max_query_len <= decode_threshold:
|
||||
return num_reqs, 0, 0, num_tokens, 0, 0
|
||||
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
is_prefill_or_extend = query_lens > decode_threshold
|
||||
is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
|
||||
first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
num_decodes = first_extend
|
||||
num_decode_tokens = query_start_loc[first_extend].item()
|
||||
if not torch.any(is_prefill_or_extend):
|
||||
return (num_decodes, 0, 0, num_decode_tokens, 0, 0)
|
||||
|
||||
num_prefills_or_extends = num_reqs - num_decodes
|
||||
num_prefill_or_extend_tokens = num_tokens - num_decode_tokens
|
||||
if not torch.any(is_prefill):
|
||||
return (
|
||||
num_decodes,
|
||||
num_prefills_or_extends,
|
||||
0,
|
||||
num_decode_tokens,
|
||||
num_prefill_or_extend_tokens,
|
||||
0,
|
||||
)
|
||||
|
||||
num_extends = first_prefill - num_decodes
|
||||
num_prefills = num_reqs - first_prefill
|
||||
|
||||
num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
|
||||
num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
|
||||
return (
|
||||
num_decodes,
|
||||
num_extends,
|
||||
num_prefills,
|
||||
num_decode_tokens,
|
||||
num_extend_tokens,
|
||||
num_prefill_tokens,
|
||||
)
|
||||
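# Illustrative example (not part of this change), with decode_threshold=1 and a
# batch already ordered decode -> extend -> prefill:
#   query lens = [1, 1, 4, 6], seq lens = [10, 7, 9, 6]
#   requests 0-1 are decodes, request 2 is an extend (seq_len > query_len, i.e.
#   it has prior context) and request 3 is a pure prefill (seq_len == query_len)
#   -> (num_decodes=2, num_extends=1, num_prefills=1,
#       num_decode_tokens=2, num_extend_tokens=4, num_prefill_tokens=6)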
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
require_uniform: bool = False,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Assuming a reordered batch, finds the boundary between prefill and decode
|
||||
requests.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: CommonAttentionMetadata object containing the
|
||||
batch metadata.
|
||||
decode_threshold: The maximum query length to be considered a decode.
|
||||
require_uniform: If True, requires that all decode requests have the
|
||||
same query length. When set, some queries may be considered prefills
|
||||
even if they are <= decode_threshold, in order to ensure uniformity.
|
||||
|
||||
Returns:
|
||||
num_decodes: The number of decode requests.
|
||||
num_prefills: The number of prefill requests.
|
||||
num_decode_tokens: The number of tokens in the decode requests.
|
||||
num_prefill_tokens: The number of tokens in the prefill requests.
|
||||
"""
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
if max_query_len <= decode_threshold and (
|
||||
not require_uniform or decode_threshold <= 1
|
||||
):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
if query_lens[0].item() > decode_threshold:
|
||||
# first request is not decode, so no decode requests
|
||||
return 0, num_reqs, 0, num_tokens
|
||||
|
||||
if require_uniform:
|
||||
# check if we are in a padded uniform batch; this is used for full-CGs, some
|
||||
# requests may have a query length of 0, but since they are padding it's fine
|
||||
# to treat them as decodes (ensures num_decodes matches the captured size)
|
||||
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
|
||||
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
|
||||
return num_reqs, 0, num_tokens, 0 # all decodes
|
||||
is_prefill = query_lens != query_lens[0]
|
||||
else:
|
||||
is_prefill = query_lens > decode_threshold
|
||||
|
||||
if not torch.any(is_prefill):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
||||
num_decodes = first_prefill
|
||||
num_prefills = num_reqs - num_decodes
|
||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
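# Illustrative example (not part of this change), with decode_threshold=1 and
# require_uniform=False:
#   query lens = [1, 1, 8, 12]   (query_start_loc = [0, 1, 2, 10, 22])
#   -> (num_decodes=2, num_prefills=2, num_decode_tokens=2, num_prefill_tokens=20)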
|
||||
|
||||
def split_prefill_chunks(
|
||||
seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Split the prefill requests into chunks such that the total sequence length
|
||||
of each chunk is less than or equal to the workspace size.
|
||||
|
||||
Args:
|
||||
seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
|
||||
workspace_size: The maximum workspace size (in tokens) per chunk.
|
||||
request_offset: The offset to add to the request indices.
|
||||
Returns:
|
||||
A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.
|
||||
"""
|
||||
chunk_bounds = []
|
||||
i, n = 0, len(seq_lens_cpu)
|
||||
assert torch.all(seq_lens_cpu <= workspace_size).item()
|
||||
|
||||
while i < n:
|
||||
start, chunk_total = i, 0
|
||||
while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
|
||||
chunk_total += s
|
||||
i += 1
|
||||
chunk_bounds.append((start + request_offset, i + request_offset))
|
||||
return chunk_bounds
|
||||
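# Illustrative example (not part of this change):
#   split_prefill_chunks(torch.tensor([5, 3, 4, 6]), workspace_size=8)
#   -> [(0, 2), (2, 3), (3, 4)]   # chunk token totals: 8, 4 and 6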
|
||||
|
||||
def reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput",
|
||||
decode_threshold: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Reorders the batch to split into prefill and decode requests; places all
|
||||
requests with <= decode_threshold tokens at the front of the batch.
|
||||
|
||||
Returns:
|
||||
True if the batch was modified, False otherwise.
|
||||
"""
|
||||
# We now want to reorder the batch into decode → extend → prefill order
|
||||
# where:
|
||||
# decode: request with num_scheduled_tokens <= decode_threshold
|
||||
# extend: non-decode request with existing context
|
||||
# prefill: non-decode request with no existing context
|
||||
# NOTE for now we loosely use "decode" to mean requests where attention is
|
||||
# likely memory-bound and "prefill" to mean requests where attention is
|
||||
# likely compute-bound.
|
||||
num_reqs = len(input_batch.req_ids)
|
||||
num_scheduled_tokens = [
|
||||
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
|
||||
]
|
||||
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
|
||||
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||
|
||||
is_prefill = num_computed_tokens_np == 0
|
||||
is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill)
|
||||
is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill)
|
||||
|
||||
# Desired order: decode → extend → prefill
|
||||
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
|
||||
req_regions[is_extend] = 1
|
||||
req_regions[is_prefill] = 2
|
||||
|
||||
num_decodes = int(is_decode.sum())
|
||||
num_extends = int(is_extend.sum())
|
||||
|
||||
target_regions = np.zeros(num_reqs, dtype=np.int32)
|
||||
target_regions[num_decodes : num_decodes + num_extends] = 1
|
||||
target_regions[num_decodes + num_extends :] = 2
|
||||
|
||||
needs_swap = req_regions != target_regions
|
||||
|
||||
if not needs_swap.any():
|
||||
return False
|
||||
|
||||
# Extract indices that need swapping and sort by target region
|
||||
orig_indices = np.where(needs_swap)[0]
|
||||
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
|
||||
src_indices = orig_indices[sorted_order]
|
||||
|
||||
src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
|
||||
|
||||
for src in src_dest_map:
|
||||
dst = src_dest_map[src]
|
||||
while src != dst:
|
||||
input_batch.swap_states(src, dst)
|
||||
# Mark dst as done by updating its destination to itself
|
||||
next_dst = src_dest_map.get(dst, dst)
|
||||
src_dest_map[dst] = dst
|
||||
dst = next_dst
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Reshapes the query tensor for the specified batch size, so that
|
||||
it has shape (batch_size, seq_len, num_heads, head_dim).
|
||||
"""
|
||||
assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
|
||||
total_tokens = query.shape[0]
|
||||
num_heads = query.shape[1]
|
||||
head_dim = query.shape[2]
|
||||
assert total_tokens % batch_size == 0, (
|
||||
f"{total_tokens=} is not divisible by {batch_size=}"
|
||||
)
|
||||
seq_len = total_tokens // batch_size
|
||||
return query.view(batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
|
||||
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Reshapes the attention output tensor, so that
|
||||
the batch_size and seq_len dimensions are combined.
|
||||
"""
|
||||
if attn_output.dim() == 3:
|
||||
# Already in the correct shape
|
||||
return attn_output
|
||||
assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
|
||||
total_tokens = attn_output.shape[0] * attn_output.shape[1]
|
||||
return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
|
||||
|
||||
|
||||
def subclass_attention_metadata(
|
||||
name_prefix: str,
|
||||
metadata_cls: Any,
|
||||
fields: list[tuple[str, Any, Any]],
|
||||
) -> Any:
|
||||
"""
|
||||
Return a new subclass of `metadata_cls` with additional fields
|
||||
"""
|
||||
name: str = name_prefix + metadata_cls.__name__ # type: ignore
|
||||
Wrapped = make_dataclass(name, fields, bases=(metadata_cls,))
|
||||
return Wrapped
|
||||
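# Illustrative usage (not part of this change; the class and field names are
# hypothetical):
#   ExtendedMeta = subclass_attention_metadata(
#       name_prefix="KVSharing",
#       metadata_cls=SomeBackendAttentionMetadata,
#       fields=[("num_logits_indices", int, field(default=0))],
#   )
#   # -> dataclass "KVSharingSomeBackendAttentionMetadata" with one extra field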
|
||||
|
||||
@runtime_checkable
|
||||
class KVSharingFastPrefillMetadata(Protocol):
|
||||
logits_indices_padded: torch.Tensor | None = None
|
||||
num_logits_indices: int | None = None
|
||||
|
||||
|
||||
def create_fast_prefill_custom_backend(
|
||||
prefix: str,
|
||||
underlying_attn_backend: type[AttentionBackend],
|
||||
) -> type[AttentionBackend]:
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class FastPrefillAttentionBuilder(underlying_builder): # type: ignore
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
new_common_attn_metadata = (
|
||||
make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
|
||||
)
|
||||
metadata = super().build(
|
||||
common_prefix_len, new_common_attn_metadata, fast_build
|
||||
)
|
||||
|
||||
class KVSharingFastPrefillAttentionMetadata(
|
||||
metadata.__class__, # type: ignore
|
||||
KVSharingFastPrefillMetadata,
|
||||
):
|
||||
def __init__(self, metadata, common_attn_metadata):
|
||||
# Shallow copy all fields in metadata cls
|
||||
for _field in fields(metadata.__class__):
|
||||
setattr(self, _field.name, getattr(metadata, _field.name))
|
||||
|
||||
self.logits_indices_padded = (
|
||||
common_attn_metadata.logits_indices_padded
|
||||
)
|
||||
self.num_logits_indices = common_attn_metadata.num_logits_indices
|
||||
|
||||
return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=FastPrefillAttentionBuilder,
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
def compute_causal_conv1d_metadata(
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
*,
|
||||
device: torch.device,
|
||||
):
|
||||
# Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync.
|
||||
assert query_start_loc_p_cpu.device.type == "cpu"
|
||||
seqlens = query_start_loc_p_cpu.diff()
|
||||
nums_dict = {} # type: ignore
|
||||
batch_ptr = None
|
||||
token_chunk_offset_ptr = None
|
||||
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||
nums = -(-seqlens // BLOCK_M)
|
||||
nums_dict[BLOCK_M] = {}
|
||||
nums_dict[BLOCK_M]["nums"] = nums
|
||||
nums_dict[BLOCK_M]["tot"] = nums.sum().item()
|
||||
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
|
||||
nums_dict[BLOCK_M]["mlist"] = mlist
|
||||
mlist_len = len(nums_dict[BLOCK_M]["mlist"])
|
||||
nums_dict[BLOCK_M]["mlist_len"] = mlist_len
|
||||
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
|
||||
offsetlist = [] # type: ignore
|
||||
for idx, num in enumerate(nums):
|
||||
offsetlist.extend(range(num))
|
||||
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
|
||||
nums_dict[BLOCK_M]["offsetlist"] = offsetlist
|
||||
|
||||
if batch_ptr is None:
|
||||
# Update default value after class definition
|
||||
batch_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
token_chunk_offset_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
else:
|
||||
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||
token_chunk_offset_ptr.resize_( # type: ignore
|
||||
MAX_NUM_PROGRAMS
|
||||
).fill_(PAD_SLOT_ID)
|
||||
|
||||
batch_ptr[0:mlist_len].copy_(mlist, non_blocking=True)
|
||||
token_chunk_offset_ptr[ # type: ignore
|
||||
0:mlist_len
|
||||
].copy_(offsetlist, non_blocking=True)
|
||||
nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
|
||||
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore
|
||||
|
||||
return nums_dict, batch_ptr, token_chunk_offset_ptr
|
||||
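# Illustrative values (not part of this change): with prefill query lens [5, 12]
# and BLOCK_M=8:
#   nums       = ceil([5, 12] / 8) = [1, 2]
#   mlist      = [0, 1, 1]   # owning request index for each grid program
#   offsetlist = [0, 0, 1]   # BLOCK_M-chunk offset within that request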
|
||||
|
||||
def get_dcp_local_seq_lens(
|
||||
seq_lens: torch.Tensor,
|
||||
dcp_size: int = 1,
|
||||
dcp_rank: int | None = None,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""While using dcp, kv_cache size stored on each rank may be different,
|
||||
so use this function to compute the per-rank decode seq_lens.
|
||||
Only DCP is considered for now; the CP case can be extended from this.
|
||||
"""
|
||||
num_requests = seq_lens.size(0)
|
||||
if dcp_rank is None:
|
||||
rank_offsets = (
|
||||
torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
|
||||
.unsqueeze(0)
|
||||
.repeat(num_requests, 1)
|
||||
)
|
||||
else:
|
||||
rank_offsets = torch.tensor(
|
||||
[[dcp_rank]], dtype=torch.int32, device=seq_lens.device
|
||||
)
|
||||
seq_lens_tiled = (
|
||||
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
|
||||
)
|
||||
base = (
|
||||
seq_lens_tiled
|
||||
// cp_kv_cache_interleave_size
|
||||
// dcp_size
|
||||
* cp_kv_cache_interleave_size
|
||||
)
|
||||
remainder = seq_lens_tiled - base * dcp_size
|
||||
remainder = torch.clip(
|
||||
remainder - rank_offsets * cp_kv_cache_interleave_size,
|
||||
0,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_local_seq_lens = base + remainder
|
||||
return dcp_local_seq_lens.squeeze(1)
|
||||
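# Illustrative example (not part of this change): seq_lens=[10], dcp_size=4,
# cp_kv_cache_interleave_size=1 gives per-rank lengths [3, 3, 2, 2]: the base is
# 10 // 4 = 2 tokens per rank, and the remaining 2 tokens land on ranks 0 and 1,
# one interleave unit each.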
|
||||
|
||||
def mamba_get_block_table_tensor(
|
||||
block_table: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
mamba_cache_mode: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get the block table tensor for mamba kernels from the input
|
||||
common_attn_metadata.block_table_tensor given different mamba cache modes.
|
||||
|
||||
- "all": input (#requests, cdiv(max_model_len, block_size));
|
||||
output (#requests, cdiv(max_model_len, block_size)).
|
||||
|
||||
- "none": input (#requests, 1 + num_speculative_blocks);
|
||||
output (#requests, 1 + num_speculative_blocks).
|
||||
|
||||
- "align": input (#requests, cdiv(max_model_len, block_size));
|
||||
output (#requests, 1 + num_speculative_blocks), which are the last
|
||||
1 + num_speculative_blocks of each request.
|
||||
"""
|
||||
if mamba_cache_mode in ("all", "none"):
|
||||
return block_table
|
||||
else:
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
# NOTE: For 0-length requests in CUDA graph, use a start_index of 0
|
||||
# to handle the invalid block table.
|
||||
start_indices = torch.clamp(
|
||||
(seq_lens - 1) // kv_cache_spec.block_size,
|
||||
min=0,
|
||||
)
|
||||
# Use int32 for arithmetic to avoid dtype promotion overhead,
|
||||
# then convert to int64 for gather (which requires Long indices)
|
||||
offsets = torch.arange(
|
||||
1 + kv_cache_spec.num_speculative_blocks,
|
||||
device=block_table.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
|
||||
return torch.gather(block_table, 1, indices_to_gather)
|
||||
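# Illustrative example of the "align" mode (not part of this change): with
# block_size=16, num_speculative_blocks=1 and seq_lens=[33], start_indices is
# (33 - 1) // 16 = 2, so columns [2, 3] of the full block table are gathered:
# the block holding the last computed token plus one speculative block.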