Sync from v0.13
This commit is contained in:
654
tests/v1/attention/test_attention_backends.py
Normal file
654
tests/v1/attention/test_attention_backends.py
Normal file
@@ -0,0 +1,654 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 attention backends without GPUModelRunner dependency."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
set_kv_cache_layout,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
# Backends exercised by the correctness tests below. The string entry
# "FLEX_ATTENTION_SLOW" selects the FlexAttention backend with its direct
# block-mask build disabled (see run_attention_backend).
BACKENDS_TO_TEST = [
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.FLASHINFER,
    AttentionBackendEnum.FLEX_ATTENTION,
    AttentionBackendEnum.TRITON_ATTN,
    AttentionBackendEnum.TREE_ATTN,
    "FLEX_ATTENTION_SLOW",
]

# Remove flashinfer from the list if it's not available
try:
    import flashinfer  # noqa: F401
except ImportError:
    BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER)
|
||||
|
||||
|
||||
def _convert_dtype_to_torch(dtype):
|
||||
"""Convert ModelDType to torch.dtype."""
|
||||
if isinstance(dtype, str):
|
||||
if dtype == "auto":
|
||||
return torch.float16 # Default dtype for testing
|
||||
elif dtype in STR_DTYPE_TO_TORCH_DTYPE:
|
||||
return STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
|
||||
# Define common batch configurations
|
||||
BATCH_SPECS = {
    # *_decode specs: every request contributes exactly one query token.
    "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
    # *_prefill specs: every request contributes a multi-token query chunk.
    "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
    # mixed_* specs: decode and prefill requests combined in one batch.
    "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
    "medium_decode": BatchSpec(
        seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
        query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
    ),
    "medium_prefill": BatchSpec(
        seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
    ),
    "mixed_medium": BatchSpec(
        seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
    ),
    "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
    "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
    "mixed_large": BatchSpec(
        seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32]
    ),
    # single_* specs: one-request batches.
    "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
    "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
}
|
||||
|
||||
|
||||
def create_and_prepopulate_kv_cache(
    k_contexts: list[torch.Tensor],
    v_contexts: list[torch.Tensor],
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int,
    common_attn_metadata: CommonAttentionMetadata,
    randomize_blocks: bool = True,
) -> torch.Tensor:
    """Create and prepopulate a paged KV cache with context data.

    The context K/V tensors are written into a freshly allocated paged
    cache, the cache blocks are (optionally) permuted to simulate a
    realistic non-contiguous allocation, and ``common_attn_metadata``'s
    block table and slot mapping are rewritten in place to match the
    final block layout.

    Args:
        k_contexts: List of key context tensors, one per sequence
        v_contexts: List of value context tensors, one per sequence
        block_size: Number of tokens held by each cache block
        num_kv_heads: Number of KV heads
        head_size: Size of each head
        dtype: Data type for the cache
        device: Device to create the cache on
        num_blocks: Total number of blocks in the cache
        common_attn_metadata: Metadata whose block_table_tensor and
            slot_mapping are populated in place
        randomize_blocks: Whether to randomly permute blocks
            or use sequential order

    Returns:
        The KV cache tensor of shape
        (2, num_blocks, block_size, num_kv_heads, head_size).
    """
    batch_size = len(k_contexts)
    seq_lens = common_attn_metadata.seq_lens_cpu
    # Per-request query lengths are the deltas of the cumulative start locs.
    query_lens = (
        common_attn_metadata.query_start_loc_cpu[1:]
        - common_attn_metadata.query_start_loc_cpu[:-1]
    )
    context_lens = common_attn_metadata.num_computed_tokens_cpu
    block_table = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping

    # Create KV cache
    kv_cache = torch.empty(
        2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
    )
    # Flat view over all (num_blocks * block_size) token slots for copying.
    kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)

    # Populate the cache with the context tokens
    # Start from block_id=1 since block_id=0 is considered the null block
    start_block_idx = 1
    for i in range(batch_size):
        k_context, v_context = k_contexts[i], v_contexts[i]
        start = start_block_idx * block_size
        end = start + k_context.shape[0]
        kv_cache_flat[0, start:end, ...] = k_context
        kv_cache_flat[1, start:end, ...] = v_context

        # Stay block aligned and allocate enough blocks for the new tokens
        start_block_idx += cdiv(int(seq_lens[i]), block_size)

    blocks_end = start_block_idx

    # Permute the context blocks (excluding block 0 which is null)
    if randomize_blocks:
        # Random permutation starting from block 1
        perm = torch.randperm(blocks_end - 1) + 1
    else:
        # Sequential order starting from block 1
        perm = torch.arange(1, blocks_end)

    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
    # Add 1 to account for starting from block 1
    inv_perm[1:] = torch.argsort(perm) + 1
    kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]

    # Construct the right block table
    # Start from block_id=1 since block_id=0 is considered the null block
    start_block_idx = 1
    for i in range(batch_size):
        num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size)
        start = start_block_idx
        end = start + num_blocks_for_seq
        block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
        start_block_idx += num_blocks_for_seq

    # Create a realistic slot mapping that corresponds to the block table
    for i in range(batch_size):
        # Positions of this request's new tokens within its full sequence.
        token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
        block_indices = token_offsets // block_size
        token_inter_block_offsets = token_offsets % block_size
        start = common_attn_metadata.query_start_loc_cpu[i]
        end = common_attn_metadata.query_start_loc_cpu[i + 1]
        slot_mapping[start:end] = block_table[
            i, block_indices
        ] * block_size + token_inter_block_offsets.to(device)

    return kv_cache
|
||||
|
||||
|
||||
class MockAttentionLayer:
    """A minimal stand-in for an attention layer.

    Exposes only the quantization-scale attributes that AttentionImpl
    forward passes read; every scale is fixed at 1.0 (i.e. no scaling).
    """

    def __init__(self, device: torch.device):
        unit = torch.tensor(1.0, device=device)
        self._q_scale = unit
        self._k_scale = unit.clone()
        self._v_scale = unit.clone()
        # Plain-float variants, required by the flashinfer backend.
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
|
||||
|
||||
|
||||
def run_attention_backend(
    backend: AttentionBackendEnum,
    kv_cache_spec: FullAttentionSpec,
    layer_names: list[str],
    vllm_config,
    device: torch.device,
    common_attn_metadata: CommonAttentionMetadata,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    sliding_window: int | None = None,
) -> torch.Tensor:
    """Run attention computation using the specified backend's AttentionImpl.

    Builds the backend's metadata from ``common_attn_metadata``,
    instantiates the backend's AttentionImpl with head counts and head
    size taken from ``vllm_config``, and runs one forward pass against
    the given paged ``kv_cache``, returning the attention output.
    """

    # Handle special case for FLEX_ATTENTION_SLOW
    actual_backend = backend

    use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
    if backend == "FLEX_ATTENTION_SLOW":
        # Same backend as FLEX_ATTENTION, but with the direct block-mask
        # build path disabled.
        actual_backend = AttentionBackendEnum.FLEX_ATTENTION
        use_direct_block_mask = False

    builder_cls, impl_cls = try_get_attention_backend(actual_backend)

    # Mock flashinfer's get_per_layer_parameters if needed
    if actual_backend == AttentionBackendEnum.FLASHINFER:
        import unittest.mock

        from vllm.v1.attention.backends.utils import PerLayerParameters

        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
            # Return mock parameters for a single layer
            head_size = vllm_config.model_config.get_head_size()
            return {
                layer_name: PerLayerParameters(
                    window_left=-1,  # No sliding window
                    logits_soft_cap=0.0,  # No soft cap
                    sm_scale=1.0 / (head_size**0.5),  # Standard scale
                )
                for layer_name in layer_names
            }

        with unittest.mock.patch(
            "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
            mock_get_per_layer_parameters,
        ):
            builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
            attn_metadata = builder.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
    else:
        # Build metadata
        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
        if actual_backend == AttentionBackendEnum.FLEX_ATTENTION:
            builder.direct_build = use_direct_block_mask
        attn_metadata = builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

    # Instantiate implementation
    num_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config
    )
    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
        vllm_config.parallel_config
    )
    head_size = vllm_config.model_config.get_head_size()
    scale = 1.0 / (head_size**0.5)
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=sliding_window,
        kv_cache_dtype="auto",
    )

    # Create mock layer and output buffer
    mock_layer = MockAttentionLayer(device)
    output = torch.empty_like(query)

    # Run forward pass
    # NOTE: The query, key, and value are already shaped correctly
    # in the calling test function.
    output = impl.forward(
        mock_layer, query, key, value, kv_cache, attn_metadata, output=output
    )

    return output
|
||||
|
||||
|
||||
def _test_backend_correctness(
    batch_spec: BatchSpec,
    model: str,
    backend_to_test: list[AttentionBackendEnum | str],
    mask_mod,
    *,
    block_size: int = 16,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    tensor_parallel_size: int = 1,
):
    """
    Test that all backends produce similar outputs to a reference implementation
    using torch.nn.functional.scaled_dot_product_attention.

    This test works by:
    1. Generating a batch of sequences with specified context and query lengths.
    2. Computing a ground-truth attention output using torch.sdpa on
       contiguous Q, K, and V tensors.
    3. Simulating vLLM's paged KV cache: It takes the context portion of the
       K/V tensors and manually places them into a paged buffer according to
       the test's (randomly generated) block table.
    4. Running each vLLM attention backend with the new queries and the
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.

    Note: When tensor_parallel_size > 1, we simulate the head partitioning
    by overriding the model config to use fewer heads, without requiring
    multiple GPUs. This tests that backends work correctly with different
    head counts.
    """
    current_platform.seed_everything(42)

    hf_config_override = None
    if tensor_parallel_size > 1:
        # NOTE(review): ModelConfig is already imported at module level.
        from vllm.config import ModelConfig

        # Derive per-rank head counts from the original model config.
        temp_config = ModelConfig(model=model, max_model_len=1)
        original_num_heads = temp_config.hf_text_config.num_attention_heads
        original_num_kv_heads = getattr(
            temp_config.hf_text_config, "num_key_value_heads", None
        )
        hf_config_override = {
            "num_attention_heads": original_num_heads // tensor_parallel_size,
        }
        if original_num_kv_heads is not None:
            hf_config_override["num_key_value_heads"] = max(
                1, original_num_kv_heads // tensor_parallel_size
            )

    vllm_config = create_vllm_config(
        model_name=model,
        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
        max_model_len=max(batch_spec.seq_lens),
        block_size=block_size,
        num_gpu_blocks=8192,
        hf_config_override=hf_config_override,
    )
    device = torch.device("cuda:0")

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # 1. Setup
    batch_size = batch_spec.batch_size
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens
    num_q_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config
    )
    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
        vllm_config.parallel_config
    )
    head_size = vllm_config.model_config.get_head_size()
    sliding_window = vllm_config.model_config.get_sliding_window()
    dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
    block_size = vllm_config.cache_config.block_size
    scale = 1.0 / (head_size**0.5)

    # 2. Generate data and compute SDPA reference output
    all_q_vllm, all_k_vllm, all_v_vllm = [], [], []
    all_sdpa_outputs = []
    k_contexts, v_contexts = [], []

    for i in range(batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        context_len = s_len - q_len

        # Generate Q, K, V for the whole sequence to be used in SDPA
        q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device)
        k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
        v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)

        # SDPA expects (N, H, L, D), so unsqueeze batch and permute
        q_sdpa_in = q.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

        if num_q_heads != num_kv_heads:
            assert num_q_heads % num_kv_heads == 0, (
                f"num_q_heads ({num_q_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )
            repeats = num_q_heads // num_kv_heads
            k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1)
            v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1)

        # Create causal mask: query token i attends to positions 0 to
        # (context_len + i)
        kv_len = s_len

        final_mask_mod = partial(mask_mod, context_len=context_len)
        block_mask = create_block_mask(
            final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device
        )
        sdpa_out_i = flex_attention(
            q_sdpa_in,
            k_sdpa_in,
            v_sdpa_in,
            block_mask=block_mask,
            scale=scale,
            enable_gqa=True,
        )

        all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

        # Inputs for vLLM backends are just the new tokens
        all_q_vllm.append(q)
        all_k_vllm.append(k_full[context_len:])
        all_v_vllm.append(v_full[context_len:])

        # Contextual K/V data used to populate the paged cache
        k_contexts.append(k_full[:context_len])
        v_contexts.append(v_full[:context_len])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    key_vllm = torch.cat(all_k_vllm, dim=0)
    value_vllm = torch.cat(all_v_vllm, dim=0)
    sdpa_output = torch.cat(all_sdpa_outputs, dim=0)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device
    )

    # 3. Simulate Paged KV Cache and a realistic slot_mapping
    kv_cache = create_and_prepopulate_kv_cache(
        k_contexts=k_contexts,
        v_contexts=v_contexts,
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=True,
    )

    # 4. Run vLLM backends and compare
    # Note: flex_attention has known Triton kernel compatibility issues
    # with test infrastructures
    for backend_name in backend_to_test:
        # FlashAttention + FlexAttention:
        #   [2, num_blocks, block_size, num_kv_heads, head_size]
        # FlashInfer + Triton:
        #   [num_blocks, 2, block_size, num_kv_heads, head_size]
        # Select the appropriate KV cache format for each backend
        kv_cache_for_backend = kv_cache
        reset_kv_cache_layout = False
        if backend_name in (
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.TRITON_ATTN,
        ):
            kv_cache_for_backend = kv_cache.transpose(0, 1)

            if backend_name == AttentionBackendEnum.FLASHINFER:
                # For FlashInfer, default to the HND layout and make the
                # cache contiguous in that layout.
                kv_cache_for_backend = (
                    kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
                )
                set_kv_cache_layout("HND")
                reset_kv_cache_layout = True
            elif backend_name == AttentionBackendEnum.TRITON_ATTN:
                kv_cache_for_backend = kv_cache_for_backend.contiguous()

        try:
            backend_output = run_attention_backend(
                backend_name,
                kv_cache_spec,
                ["placeholder"],
                vllm_config,
                device,
                common_attn_metadata,
                query_vllm,
                key_vllm,
                value_vllm,
                kv_cache_for_backend,
                sliding_window=sliding_window,
            )
        finally:
            # Restore the global layout even if the backend raises.
            if reset_kv_cache_layout:
                set_kv_cache_layout(None)

        # Check shape and dtype consistency
        assert backend_output.shape == sdpa_output.shape, (
            f"[{backend_name}] shape {backend_output.shape} != "
            f"SDPA shape {sdpa_output.shape}"
        )
        assert backend_output.dtype == sdpa_output.dtype, (
            f"[{backend_name}] dtype {backend_output.dtype} != "
            f"SDPA dtype {sdpa_output.dtype}"
        )

        assert torch.isfinite(backend_output).all(), (
            f"[{backend_name}] produced non-finite values"
        )

        # Check numerical similarity
        def error_msg(msg: str, backend_name: str):
            return f"[{backend_name}] output differs from SDPA baseline. {msg}"

        torch.testing.assert_close(
            backend_output,
            sdpa_output,
            rtol=rtol,
            atol=atol,
            msg=partial(error_msg, backend_name=backend_name),
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_small",
        "medium_decode",
        "medium_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "single_decode",
        "single_prefill",
    ],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backend's correctness with causal attention."""

    def causal_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
    ):
        # Query token q_idx sits at absolute position q_idx + context_len
        # and may attend to any kv position at or before it.
        return (q_idx + context_len) >= kv_idx

    spec = BATCH_SPECS[batch_spec_name]

    # On torch >= 2.9, FlexAttention runs with a 128-token block size;
    # every other backend keeps the default block size.
    if is_torch_equal_or_newer("2.9.0.dev0"):
        large_block_backends = [AttentionBackendEnum.FLEX_ATTENTION]
    else:
        large_block_backends = []
    small_block_backends = [
        backend for backend in BACKENDS_TO_TEST if backend not in large_block_backends
    ]

    _test_backend_correctness(
        spec,
        model,
        small_block_backends,
        causal_mask_mod,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Fast FlexAttention needs to run with block_size=128
    if large_block_backends:
        _test_backend_correctness(
            spec,
            model,
            large_block_backends,
            causal_mask_mod,
            block_size=128,
            tensor_parallel_size=tensor_parallel_size,
        )
|
||||
|
||||
|
||||
# Backends exercised by the sliding-window tests (a subset of
# BACKENDS_TO_TEST). "FLEX_ATTENTION_SLOW" again selects FlexAttention
# with its direct block-mask build disabled (see run_attention_backend).
SLIDING_WINDOW_BACKENDS_TO_TEST = [
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.FLEX_ATTENTION,
    AttentionBackendEnum.TRITON_ATTN,
    "FLEX_ATTENTION_SLOW",
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "mixed_large",
    ],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sliding_window_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backend's correctness with sliding window attention."""

    def sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # Causal constraint combined with a bounded lookback window.
        absolute_q = q_idx + context_len
        is_causal = absolute_q >= kv_idx
        in_window = absolute_q - kv_idx < sliding_window
        return is_causal & in_window

    spec = BATCH_SPECS[batch_spec_name]
    model_config = ModelConfig(model=model, max_model_len=max(spec.seq_lens))
    sliding_window = model_config.get_sliding_window()
    mask_mod_with_window = partial(
        sliding_window_mask_mod, sliding_window=sliding_window
    )

    # On torch >= 2.9, FlexAttention runs with a 128-token block size;
    # every other backend keeps the default block size.
    if is_torch_equal_or_newer("2.9.0.dev0"):
        large_block_backends = [AttentionBackendEnum.FLEX_ATTENTION]
    else:
        large_block_backends = []
    small_block_backends = [
        backend
        for backend in SLIDING_WINDOW_BACKENDS_TO_TEST
        if backend not in large_block_backends
    ]

    _test_backend_correctness(
        spec,
        model,
        small_block_backends,
        mask_mod_with_window,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Fast FlexAttention needs to run with block_size=128
    if large_block_backends:
        _test_backend_correctness(
            spec,
            model,
            large_block_backends,
            mask_mod_with_window,
            block_size=128,
            tensor_parallel_size=tensor_parallel_size,
        )
|
||||
111
tests/v1/attention/test_attention_backends_selection.py
Normal file
111
tests/v1/attention/test_attention_backends_selection.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for mamba attention backend selectors."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.short_conv import ShortConv
|
||||
from vllm.model_executor.models.minimax_text_01 import MiniMaxText01LinearAttention
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
|
||||
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "layer_class, init_kwargs, expected_backend, expected_mamba_type",
    [
        (
            MambaMixer,
            dict(
                hidden_size=128,
                ssm_state_size=16,
                conv_kernel_size=4,
                intermediate_size=256,
                time_step_rank=8,
                use_conv_bias=True,
                use_bias=False,
                use_rms_norm=True,
            ),
            Mamba1AttentionBackend,
            "mamba1",
        ),
        (
            MambaMixer2,
            dict(
                hidden_size=128,
                ssm_state_size=16,
                conv_kernel_size=4,
                intermediate_size=256,
                use_conv_bias=True,
                use_bias=False,
                n_groups=1,
                num_heads=8,
                head_dim=32,
            ),
            Mamba2AttentionBackend,
            "mamba2",
        ),
        (
            MiniMaxText01LinearAttention,
            dict(
                hidden_size=128,
                hidden_inner_size=256,
                num_heads=8,
                head_dim=32,
                max_position=2048,
                block_size=64,
                num_hidden_layer=12,
                layer_idx=0,
                linear_layer_idx=0,
            ),
            LinearAttentionBackend,
            "linear_attention",
        ),
        (
            ShortConv,
            dict(
                config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
                dim=128,
                layer_idx=0,
            ),
            ShortConvAttentionBackend,
            "short_conv",
        ),
    ],
)
def test_mamba_layers_get_attn_backend(
    dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
):
    """Test that Mamba-like layers return the correct attention backend."""
    # Instantiate the layer with minimal kwargs, then verify the unified
    # get_attn_backend() / mamba_type interface reports the expected values.
    layer = layer_class(**init_kwargs)

    backend_class = layer.get_attn_backend()
    assert backend_class is expected_backend
    assert layer.mamba_type == expected_mamba_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "layer_class,expected_backend,expected_mamba_type",
    [
        (MambaMixer, Mamba1AttentionBackend, "mamba1"),
        (MambaMixer2, Mamba2AttentionBackend, "mamba2"),
        (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
        (ShortConv, ShortConvAttentionBackend, "short_conv"),
    ],
)
def test_mamba_layers_have_unified_interface(
    layer_class, expected_backend, expected_mamba_type
):
    """Verify each Mamba-style layer class exposes the unified
    get_attn_backend / mamba_type interface (checked on the class, without
    instantiating it)."""
    required_attrs = (
        ("get_attn_backend", "method"),
        ("mamba_type", "property"),
    )
    for attr, kind in required_attrs:
        assert hasattr(layer_class, attr), (
            f"{layer_class.__name__} should have {attr} {kind}"
        )
|
||||
378
tests/v1/attention/test_attention_splitting.py
Normal file
378
tests/v1/attention/test_attention_splitting.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_attention_backends import BATCH_SPECS
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
|
||||
|
||||
|
||||
@pytest.fixture
def sample_query_start_loc():
    """Cumulative query start offsets for a 5-request batch.

    Entry i is the first token index of request i, so the per-request
    query lengths are [5, 7, 8, 15, 15].
    """
    return torch.tensor([0, 5, 12, 20, 35, 50])
|
||||
|
||||
|
||||
def test_basic_slice_middle(sample_query_start_loc):
    """Test slicing from middle of tensor"""
    req_slice = slice(1, 3)  # slice from index 1 to 3
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    # Entries [5, 12, 20] rebased so the slice starts at 0.
    expected = torch.tensor([0, 7, 15])
    assert torch.equal(result, expected)
|
||||
|
||||
|
||||
def test_slice_from_beginning(sample_query_start_loc):
    """Test slicing from the beginning of tensor"""
    req_slice = slice(0, 2)  # slice from index 0 to 2
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    # A prefix slice needs no rebasing: [0, 5, 12] comes back unchanged.
    expected = torch.tensor([0, 5, 12])
    assert torch.equal(result, expected)
|
||||
|
||||
|
||||
def test_slice_to_end(sample_query_start_loc):
    """Test slicing to the end of tensor"""
    req_slice = slice(3, 5)  # slice from index 3 to 5 (last index)
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    # Entries [20, 35, 50] rebased so the slice starts at 0.
    expected = torch.tensor([0, 15, 30])
    assert torch.equal(result, expected)
|
||||
|
||||
|
||||
def test_single_element_slice(sample_query_start_loc):
    """Test slice that results in single element"""
    req_slice = slice(2, 3)  # slice from index 2 to 3
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    # One request with query length 20 - 12 = 8.
    expected = torch.tensor([0, 8])
    assert torch.equal(result, expected)
|
||||
|
||||
|
||||
def test_full_tensor_slice(sample_query_start_loc):
    """Test slicing the entire tensor"""
    req_slice = slice(0, 5)  # slice entire tensor
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    # The full slice reproduces the original tensor unchanged.
    expected = torch.tensor([0, 5, 12, 20, 35, 50])
    assert torch.equal(result, expected)
|
||||
|
||||
|
||||
def test_slice_bounds_edge_cases(sample_query_start_loc):
    """Test a slice covering exactly the final request."""
    # Test slice that goes exactly to the last element
    req_slice = slice(4, 5)  # Last index
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    # One request with query length 50 - 35 = 15.
    expected = torch.tensor([0, 15])
    assert torch.equal(result, expected)
|
||||
|
||||
|
||||
@pytest.fixture
def small_decode_metadata():
    """Create metadata for small decode batch"""
    # "small_decode": seq_lens=[32, 40], query_lens=[1, 1].
    batch_spec = BATCH_SPECS["small_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
|
||||
@pytest.fixture
def large_decode_metadata():
    """Create metadata for large decode batch"""
    # "large_decode": 32 requests, seq_len 2048 each, query_len 1 each.
    batch_spec = BATCH_SPECS["large_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
|
||||
@pytest.fixture
def mixed_small_metadata():
    """Create metadata for mixed small batch"""
    # "mixed_small": seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5].
    batch_spec = BATCH_SPECS["mixed_small"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
|
||||
# Tests for _make_metadata_with_slice
|
||||
def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
    """Test slicing decode batch metadata"""
    # Split first request only
    ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1))

    result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)

    # Check sliced results
    assert result.num_reqs == 1  # slice(0, 1) gives 1 request
    assert result.num_actual_tokens == 1  # slice(0, 1) gives 1 token
    assert result.max_query_len == 1
    assert torch.equal(result.query_start_loc, torch.tensor([0, 1]))
    # First request of "small_decode" has seq_len 32.
    assert torch.equal(result.seq_lens, torch.tensor([32]))
|
||||
|
||||
|
||||
def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
    """Test slicing mixed batch metadata"""
    # "mixed_small" has query_lens [1, 1, 5, 5]; requests 1-2 cover tokens 1-6.
    ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7))  # Requests 1-3, tokens 1-7

    result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)

    assert result.num_reqs == 2  # slice(1, 3) gives 2 requests
    assert result.num_actual_tokens == 6  # slice(1, 7) gives 6 tokens
    assert result.max_query_len == 5
    # Rebased cumulative starts: request 1 has 1 token, request 2 has 5.
    assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6]))
    assert torch.equal(result.seq_lens, torch.tensor([40, 48]))
|
||||
|
||||
|
||||
def test_split_attn_metadata_decode_batch(large_decode_metadata):
    """Test splitting a decode batch into two equal microbatches.

    In a pure-decode batch every request contributes exactly one token, so
    the request count equals the token count and a single midpoint is valid
    for both the request slice and the token slice.
    """
    # Renamed from the original's misleading `num_tokens`: the value read is
    # the number of requests (which only coincides with the token count
    # because this is a decode-only batch).
    num_reqs = large_decode_metadata.num_reqs
    mid_point = num_reqs // 2
    ubatch_slices = [
        UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
        UBatchSlice(slice(mid_point, num_reqs), slice(mid_point, num_reqs)),
    ]

    results = split_attn_metadata(ubatch_slices, large_decode_metadata)

    assert len(results) == 2

    # Check first split
    assert results[0].num_reqs == mid_point
    assert results[0].num_actual_tokens == mid_point
    # "large_decode" spec uses seq_len 2048 for every request.
    assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point))

    # Check second split
    assert results[1].num_reqs == mid_point
    assert results[1].num_actual_tokens == mid_point
    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
|
||||
|
||||
|
||||
def apply_split_decodes_and_prefills(
    query_lens: list[int],
    decode_threshold: int,
    require_uniform: bool,
    padded_num_tokens: int | None = None,
):
    """Helper function to apply split_decodes_and_prefills and return
    the results.

    Args:
        query_lens: Per-request scheduled query lengths for the batch.
        decode_threshold: Max query length still classified as a decode.
        require_uniform: Whether decodes must all share the same query length.
        padded_num_tokens: If set, overrides ``num_actual_tokens`` to simulate
            a batch padded with zero-length requests.

    Returns:
        Whatever ``split_decodes_and_prefills`` returns — the callers below
        unpack it as (num_decodes, num_prefills, num_decode_tokens,
        num_prefill_tokens).
    """
    device = torch.device("cpu")
    # Arbitrary increasing context sizes; only the query lengths matter here.
    seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
    common_metadata = create_common_attn_metadata(
        BatchSpec(seq_lens=seq_lens, query_lens=query_lens),
        block_size=16,
        device=device,
    )

    if padded_num_tokens is not None:
        common_metadata.num_actual_tokens = padded_num_tokens

    return split_decodes_and_prefills(
        common_metadata,
        decode_threshold=decode_threshold,
        require_uniform=require_uniform,
    )
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_ones():
    """All query lengths are 1 -> every request is classified as a decode."""
    query_lens = [1, 1, 1]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 1, False)
    )
    assert num_decodes == 3
    assert num_prefills == 0
    assert num_decode_tokens == 3
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
    """Varying query lengths all <= threshold -> all decodes (non-uniform ok)."""
    query_lens = [1, 2, 1, 3, 2, 1, 2]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 3, False)
    )
    assert num_decodes == 7
    assert num_prefills == 0
    assert num_decode_tokens == sum(query_lens)
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_prefills():
    """All query lengths exceed the threshold -> every request is a prefill."""
    query_lens = [4, 5, 6, 7]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 3, False)
    )
    assert num_decodes == 0
    assert num_prefills == 4
    assert num_decode_tokens == 0
    assert num_prefill_tokens == sum(query_lens)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
    """Leading short requests become decodes, the rest prefills."""
    query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 4, False)
    )
    assert num_decodes == 4  # 2, 1, 3, 4 are all <= 4
    assert num_prefills == 4  # 5, 6, 7, 8 are all > 4
    assert num_decode_tokens == 10  # 2 + 1 + 3 + 4
    assert num_prefill_tokens == 26  # 5 + 6 + 7 + 8
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_all_ones():
    """Uniform mode with all-ones query lengths -> all decodes."""
    query_lens = [1, 1, 1]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 1, True)
    )
    assert num_decodes == 3
    assert num_prefills == 0
    assert num_decode_tokens == 3
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_all_short_decodes():
    """Uniform mode: the decode prefix stops at the first differing length."""
    query_lens = [2, 2, 1, 3, 2, 1, 2]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 3, True)
    )
    # Only the leading run of equal lengths (2, 2) counts as decodes.
    assert num_decodes == 2
    assert num_prefills == 5
    assert num_decode_tokens == 4
    assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_all_prefills():
    """Uniform mode with every length above the threshold -> all prefills."""
    query_lens = [4, 5, 6, 7]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 3, True)
    )
    assert num_decodes == 0
    assert num_prefills == 4
    assert num_decode_tokens == 0
    assert num_prefill_tokens == sum(query_lens)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
    """Uniform mode: equal-length short prefix is kept as decodes."""
    query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 4, True)
    )
    assert num_decodes == 3  # 2, 2, 2 are all <= 4 and uniform
    assert num_prefills == 5  # 4, 5, 6, 7, 8 are all > 4
    assert num_decode_tokens == 6  # 2 + 2 + 2
    assert num_prefill_tokens == 30  # 4 + 5 + 6 + 7 + 8
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
    """Uniform mode: a length mismatch cuts the decode prefix after request 0."""
    query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 4, True)
    )
    assert num_decodes == 1  # only the first 2 is taken as decode
    assert num_prefills == 7  # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
    assert num_decode_tokens == 2  # only the first 2
    assert num_prefill_tokens == (sum(query_lens) - 2)  # rest of the tokens
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
    """uniform batch where all query lengths are identical with 0 length padded reqs."""
    # All query lengths are 2, with decode_threshold=3 (so 2 <= 3)
    # This triggers the padded uniform path at line 891
    query_lens = [2, 2, 2, 0]
    # num_actual_tokens is overridden to the padded size (3*2 real + padding).
    padded_num_tokens = 8
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens)
    )
    # With uniform batch, all requests are treated as decodes
    assert num_decodes == 4
    assert num_prefills == 0
    # The decode token count includes the padding tokens.
    assert num_decode_tokens == padded_num_tokens
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
    [
        # Split in the middle of request 1
        ([32, 40], [8, 8], 12, 2, 1),
        # Split inside the first request
        ([32, 40], [8, 8], 4, 1, 2),
    ],
)
def test_prefill_split_across_ubatches(
    seq_lens, query_lens, split_point, expected_first_reqs, expected_second_reqs
):
    """Test splitting a prefill across ubatches

    When the token split point falls inside a request's query, that request
    appears in both microbatches: the first gets a partial chunk and the
    second gets the continuation. This verifies token/request counts,
    query-length continuity and seq_len adjustments on both sides.
    """
    import numpy as np

    device = torch.device("cpu")
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
    common = create_common_attn_metadata(batch_spec, block_size=16, device=device)

    num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
    qsl_np = common.query_start_loc_cpu.numpy()
    num_tokens = common.num_actual_tokens

    ubatch_slices, _ = maybe_create_ubatch_slices(
        True,
        num_scheduled_tokens,
        num_tokens,
        batch_spec.batch_size,
        split_point=split_point,
    )
    assert ubatch_slices is not None and len(ubatch_slices) == 2

    first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
    second_meta = _make_metadata_with_slice(ubatch_slices[1], common)

    # Token counts match the split
    assert first_meta.num_actual_tokens == split_point
    assert second_meta.num_actual_tokens == num_tokens - split_point

    # Number of requests per ubatch
    assert first_meta.num_reqs == expected_first_reqs
    assert second_meta.num_reqs == expected_second_reqs

    # Identify which request is split and how many tokens are in the first chunk
    split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
    tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
    orig_q_lens = common.query_start_loc_cpu[1:] - common.query_start_loc_cpu[:-1]

    # Check query length continuity: first-chunk + second-chunk == original qlen
    # First ubatch last request query length
    qlen_first_last = int(
        first_meta.query_start_loc_cpu[-1] - first_meta.query_start_loc_cpu[-2]
    )
    # Second ubatch first request query length
    qlen_second_first = int(
        second_meta.query_start_loc_cpu[1] - second_meta.query_start_loc_cpu[0]
    )
    assert qlen_first_last == tokens_in_first_chunk
    assert qlen_first_last + qlen_second_first == int(orig_q_lens[split_req_idx])

    # Check seq_lens adjustments
    # Context lengths per original request
    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]

    # First ubatch: last request's seq_len should be
    # context + tokens_in_first_chunk
    expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk
    assert int(first_meta.seq_lens[-1]) == expected_seqlen

    # For full preceding requests in first ubatch, seq_lens should match
    # originals
    for i in range(first_meta.num_reqs - 1):
        assert int(first_meta.seq_lens[i]) == seq_lens[i]

    # Second ubatch: first request (continuation) seq_len should be full
    # original
    assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx]
    # Any following full requests in second ubatch should match originals
    for j in range(1, second_meta.num_reqs):
        # Map to original request index
        orig_idx = split_req_idx + j
        assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx]
126
tests/v1/attention/test_batch_reordering.py
Normal file
126
tests/v1/attention/test_batch_reordering.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_prefills
|
||||
|
||||
|
||||
class MockInputBatch:
    """Minimal stand-in for an input batch.

    Holds request ids and the per-request computed-token counts, and exposes
    the ``swap_states`` operation that the batch-reorder helper calls.
    """

    def __init__(self, req_ids, num_computed_tokens_cpu):
        self.req_ids = req_ids
        self.num_computed_tokens_cpu = num_computed_tokens_cpu

    def swap_states(self, i, j):
        """Exchange all per-request state between positions ``i`` and ``j``."""
        ids = self.req_ids
        ids[i], ids[j] = ids[j], ids[i]
        counts = self.num_computed_tokens_cpu
        counts[i], counts[j] = counts[j], counts[i]
|
||||
|
||||
|
||||
class MockSchedulerOutput:
    """Bare-bones scheduler output exposing only ``num_scheduled_tokens``.

    ``num_scheduled_tokens`` is a mapping of request id -> scheduled token
    count, as built by the test below.
    """

    def __init__(self, num_scheduled_tokens):
        self.num_scheduled_tokens = num_scheduled_tokens
|
||||
|
||||
|
||||
@dataclass
class ReorderTestCase:
    """One scenario for reorder_batch_to_split_decodes_and_prefills."""

    requests: list[tuple[int, int]]  # (num_scheduled_tokens, num_computed_tokens)
    # Expected final order of requests, as indices into `requests`.
    expected_order: list[int]
    # Whether the reorder helper is expected to report that it moved anything.
    expected_modified: bool
    # Max query length still treated as a decode.
    decode_threshold: int = 1
|
||||
|
||||
|
||||
# Test cases for batch reordering
|
||||
REORDER_TEST_CASES = {
|
||||
"all_decodes": ReorderTestCase(
|
||||
requests=[(1, 10), (1, 20), (1, 30)],
|
||||
expected_order=[0, 1, 2],
|
||||
expected_modified=False,
|
||||
),
|
||||
"all_prefills": ReorderTestCase(
|
||||
requests=[(100, 100), (200, 200), (300, 300)],
|
||||
expected_order=[0, 1, 2],
|
||||
expected_modified=False,
|
||||
),
|
||||
"mixed_interleaved": ReorderTestCase(
|
||||
requests=[(100, 100), (1, 10), (200, 200), (1, 20)],
|
||||
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
|
||||
expected_modified=True,
|
||||
),
|
||||
"already_ordered": ReorderTestCase(
|
||||
requests=[(1, 10), (1, 20), (100, 100), (200, 0)],
|
||||
expected_order=[0, 1, 2, 3],
|
||||
expected_modified=False,
|
||||
),
|
||||
"single_request": ReorderTestCase(
|
||||
requests=[(1, 10)],
|
||||
expected_order=[0],
|
||||
expected_modified=False,
|
||||
),
|
||||
"higher_threshold": ReorderTestCase(
|
||||
requests=[(2, 10), (3, 20), (5, 30), (6, 40)],
|
||||
expected_order=[0, 1, 2, 3],
|
||||
expected_modified=False,
|
||||
decode_threshold=4,
|
||||
),
|
||||
"decodes_at_end": ReorderTestCase(
|
||||
requests=[(100, 100), (200, 200), (1, 10), (1, 20)],
|
||||
expected_order=[2, 3, 0, 1],
|
||||
expected_modified=True,
|
||||
),
|
||||
"decode_extend_prefill": ReorderTestCase(
|
||||
requests=[(100, 0), (10, 50), (1, 10)],
|
||||
expected_order=[2, 1, 0],
|
||||
expected_modified=True,
|
||||
),
|
||||
"extend_prefill_only": ReorderTestCase(
|
||||
requests=[(100, 0), (10, 50), (200, 0), (20, 75)],
|
||||
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
|
||||
expected_modified=True,
|
||||
),
|
||||
"complicated_mixed_interleaved": ReorderTestCase(
|
||||
requests=[
|
||||
(1, 20),
|
||||
(1, 50),
|
||||
(374, 0),
|
||||
(300, 20),
|
||||
(1, 20),
|
||||
(256, 0),
|
||||
(1, 5),
|
||||
(27, 0),
|
||||
(1, 4),
|
||||
],
|
||||
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
|
||||
expected_modified=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "test_case", REORDER_TEST_CASES.values(), ids=REORDER_TEST_CASES.keys()
)
def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase):
    """Verify request ordering and the modified flag for each scenario."""
    # Synthesize request ids r0..rN matching the test-case tuples.
    req_ids = [f"r{i}" for i in range(len(test_case.requests))]
    num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32)
    num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)}

    input_batch = MockInputBatch(req_ids, num_computed_tokens)
    scheduler_output = MockSchedulerOutput(num_scheduled_tokens)

    modified = reorder_batch_to_split_decodes_and_prefills(
        input_batch, scheduler_output, decode_threshold=test_case.decode_threshold
    )

    expected_req_ids = [f"r{i}" for i in test_case.expected_order]

    assert modified == test_case.expected_modified, (
        f"Expected modified={test_case.expected_modified}, got {modified}"
    )
    assert input_batch.req_ids == expected_req_ids, (
        f"Expected order {expected_req_ids}, got {input_batch.req_ids}"
    )
|
||||
201
tests/v1/attention/test_chunked_local_attention.py
Normal file
201
tests/v1/attention/test_chunked_local_attention.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches
|
||||
|
||||
|
||||
@dataclass
class LocalAttentionTestData:
    """Inputs and expected outputs for make_local_attention_virtual_batches."""

    # Input parameters
    batch_spec: BatchSpec
    attn_chunk_size: int
    block_size: int
    # Expected return values
    expected_q_seqlens: list[int]
    expected_k_seqlens: list[int]
    expected_local_block_table: list[list[int]]
|
||||
|
||||
test_data_list = [
|
||||
# Same as example in docstring of make_local_attention_virtual_batches
|
||||
# except block table has 9 columns instead of 10
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[4, 10, 5],
|
||||
seq_lens=[6, 17, 9],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
|
||||
expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
|
||||
# 2 pages per local branch
|
||||
# (chunk size 4 // block size 2)
|
||||
expected_local_block_table=[
|
||||
[0, 1], # local-batch 0, (batch 0, starting from k[0])
|
||||
[2, 3], # local-batch 1, (batch 0, starting from k[4])
|
||||
[11, 12], # local-batch 2, (batch 1, starting from k[4])
|
||||
[13, 14], # local-batch 3, (batch 1, starting from k[8])
|
||||
[15, 16], # local-batch 4, (batch 1, starting from k[12])
|
||||
[17, 17], # local-batch 5, (batch 1, starting from k[16])
|
||||
[20, 21], # local-batch 6, (batch 2, starting from k[4])
|
||||
[22, 23], # local-batch 7, (batch 2, starting from k[8])
|
||||
],
|
||||
),
|
||||
# Case where block indices are not clipped to block table ncols-1
|
||||
# because tokens_in_last_block == attn_chunk_size
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[8],
|
||||
seq_lens=[12],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[4, 4],
|
||||
expected_k_seqlens=[4, 4],
|
||||
expected_local_block_table=[
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
],
|
||||
),
|
||||
# Case where all kv_seq positions are involved in attn
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[7],
|
||||
# 10 - 7 = 3 previously computed tokens
|
||||
seq_lens=[10],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[1, 4, 2],
|
||||
expected_k_seqlens=[4, 4, 2],
|
||||
expected_local_block_table=[
|
||||
[0, 1],
|
||||
[2, 3],
|
||||
[4, 4],
|
||||
],
|
||||
),
|
||||
# Case where attn_chunk_size > kv_seq_len
|
||||
# so no extra mini virtual batches are created
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[4],
|
||||
seq_lens=[6],
|
||||
),
|
||||
# Larger than kv_seq_len
|
||||
attn_chunk_size=10,
|
||||
block_size=2,
|
||||
# No change to q_seqlens and k_seqlens
|
||||
expected_q_seqlens=[4],
|
||||
expected_k_seqlens=[6],
|
||||
# In this case, we only need a block-table like:
|
||||
# block_table = [ [0, 1, 2] ] # 1 batch, 3 pages
|
||||
# But we need to pad it to 5 pages per local batch
|
||||
# because currently the pages_per_local_batch
|
||||
# is calculated as (attn_chunk_size // block_size)
|
||||
expected_local_block_table=[
|
||||
[0, 1, 2, 2, 2],
|
||||
],
|
||||
),
|
||||
# Block size equal to chunk size
|
||||
# Expect single page per batch in local batch table
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[6, 6],
|
||||
seq_lens=[8, 8],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=4,
|
||||
expected_q_seqlens=[2, 4, 2, 4],
|
||||
expected_k_seqlens=[4, 4, 4, 4],
|
||||
# Initial block table = [
|
||||
# [0, 1], < batch 0
|
||||
# [2, 3], < batch 1
|
||||
# ]
|
||||
expected_local_block_table=[
|
||||
[0], # local-batch 0, (batch 0, starting from k[0])
|
||||
[1], # local-batch 1, (batch 0, starting from k[4])
|
||||
[2], # local-batch 1, (batch 0, starting from k[0])
|
||||
[3], # local-batch 1, (batch 0, starting from k[4])
|
||||
],
|
||||
),
|
||||
# Case where query falls in the second attention chunk
|
||||
# k_toks > 0 1 2 3 4
|
||||
# q_toks v _____________
|
||||
# 0 | 1
|
||||
# 1 | 1 1
|
||||
# 2 | 1 1 1
|
||||
# 3 | 1 1 1 1
|
||||
# 4 | 1
|
||||
# where tokens 0,1,2,3 have been pre-computed
|
||||
LocalAttentionTestData(
|
||||
batch_spec=BatchSpec(
|
||||
query_lens=[1],
|
||||
seq_lens=[5],
|
||||
),
|
||||
attn_chunk_size=4,
|
||||
block_size=2,
|
||||
expected_q_seqlens=[1],
|
||||
expected_k_seqlens=[1],
|
||||
expected_local_block_table=[
|
||||
[2, 2],
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_data_list)
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
    """Check virtual-batch splitting against hand-computed expectations.

    Requires a CUDA device; the helper's outputs (query/key seqlens and the
    remapped block table) are compared with the per-case expected values.
    """
    device = torch.device("cuda:0")
    batch_spec = test_data.batch_spec
    attn_chunk_size = test_data.attn_chunk_size
    block_size = test_data.block_size
    expected_q_seqlens = test_data.expected_q_seqlens
    expected_k_seqlens = test_data.expected_k_seqlens
    expected_local_block_table = test_data.expected_local_block_table

    # Create common attention metadata
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size,
        device,
        # Use torch.arange instead of torch.randint so we can assert on
        # block table tensor values. The block table will have shape
        # (num_batches, cdiv(max_seq_len, block_size)) and the values will be
        # arranged from 0 to cdiv(max_seq_len, block_size)-1
        arange_block_indices=True,
    )

    # Call the function
    result = make_local_attention_virtual_batches(
        attn_chunk_size, common_attn_metadata, block_size
    )

    # Convert to numpy for easier comparison
    actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
    actual_k_seqlens = result.seq_lens_cpu.numpy()

    # Check that all query lengths are less than or equal to attn_chunk_size
    assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
    # Check that all key lengths are less than or equal to attn_chunk_size
    assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens)
    # Check that the total number of query tokens is preserved
    assert sum(actual_q_seqlens) == sum(batch_spec.query_lens)

    # Verify results
    np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens)
    np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens)

    expected_block_table_tensor = torch.tensor(
        expected_local_block_table, dtype=torch.int32, device=device
    )

    # Printed output only shows up on failure under pytest's capture.
    print(f"Expected block table:\n{expected_block_table_tensor}")
    print(f"Actual block table:\n{result.block_table_tensor}")

    torch.testing.assert_close(result.block_table_tensor, expected_block_table_tensor)
|
||||
819
tests/v1/attention/test_mla_backends.py
Normal file
819
tests/v1/attention/test_mla_backends.py
Normal file
@@ -0,0 +1,819 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency.
|
||||
|
||||
Known Issues:
|
||||
- FLASH_ATTN_MLA backend occasionally produces NaN values in
|
||||
test_backend_correctness[mixed_small] when run after
|
||||
test_backend_correctness[small_prefill], but passes when run alone.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
]
|
||||
|
||||
# Remove sm100 backends from the list if not using sm100
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
|
||||
|
||||
# Remove FLASH_ATTN_MLA from the list if not supported
|
||||
if not flash_attn_supports_mla():
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA)
|
||||
|
||||
# Remove FLASHMLA from the list if not supported
|
||||
if not is_flashmla_dense_supported()[0]:
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
|
||||
|
||||
SPEC_DECODE_BACKENDS = []
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
builder_cls, _ = try_get_attention_backend(backend)
|
||||
query_len_support = getattr(
|
||||
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
|
||||
)
|
||||
if query_len_support != QueryLenSupport.SINGLE_ONLY:
|
||||
SPEC_DECODE_BACKENDS.append(backend)
|
||||
|
||||
BACKEND_BLOCK_SIZES = {}
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
|
||||
if supported_sizes:
|
||||
default_size = supported_sizes[0]
|
||||
block_size = (
|
||||
default_size if isinstance(default_size, int) else default_size.base
|
||||
)
|
||||
else:
|
||||
block_size = 16
|
||||
BACKEND_BLOCK_SIZES[backend] = block_size
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def _convert_dtype_to_torch(dtype):
|
||||
"""Convert ModelDType to torch.dtype."""
|
||||
if isinstance(dtype, str):
|
||||
if dtype == "auto":
|
||||
return torch.float16 # Default dtype for testing
|
||||
elif dtype in STR_DTYPE_TO_TORCH_DTYPE:
|
||||
return STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
|
||||
# Define common batch configurations
|
||||
BATCH_SPECS = {
|
||||
"small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode": BatchSpec(
|
||||
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
|
||||
),
|
||||
"medium_prefill": BatchSpec(
|
||||
seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
|
||||
),
|
||||
"mixed_medium": BatchSpec(
|
||||
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
|
||||
),
|
||||
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
"spec_decode_small": BatchSpec(
|
||||
seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
|
||||
),
|
||||
"spec_decode_medium": BatchSpec(
|
||||
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def create_and_prepopulate_kv_cache(
    kv_c_contexts: list[torch.Tensor],
    k_pe_contexts: list[torch.Tensor],
    block_size: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int,
    common_attn_metadata: CommonAttentionMetadata,
    randomize_blocks: bool = True,
    kv_cache_dtype: str | None = None,
    scale: float | torch.Tensor = 1.0,
) -> torch.Tensor:
    """Create and prepopulate an MLA KV cache with context data.

    Args:
        kv_c_contexts: List of latent KV context tensors for each sequence
        k_pe_contexts: List of key positional embedding context tensors
                      for each sequence
        block_size: Size of each block
        head_size: Size of each head (latent dimension)
        dtype: Data type for the cache
        device: Device to create the cache on
        num_blocks: Total number of blocks in the cache
        common_attn_metadata: Common attention metadata
        randomize_blocks: Whether to randomly permute blocks
                         or use sequential order
        kv_cache_dtype: Optional kv cache dtype string. When set to
                       "fp8_ds_mla" the cache is populated using the
                       fp8 DeepSeek MLA layout via concat_and_cache_mla.
        scale: Scaling factor forwarded to concat_and_cache_mla when the
              fp8 cache layout is requested.

    Returns:
        MLA KV cache tensor

    Side effects:
        Writes the block table and slot mapping tensors referenced by
        ``common_attn_metadata`` in place.
    """
    batch_size = len(kv_c_contexts)
    seq_lens = common_attn_metadata.seq_lens_cpu
    # Per-request query lengths, recovered from the cumulative start offsets.
    query_lens = (
        common_attn_metadata.query_start_loc_cpu[1:]
        - common_attn_metadata.query_start_loc_cpu[:-1]
    )
    context_lens = common_attn_metadata.num_computed_tokens_cpu
    block_table = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping

    use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"

    if use_fp8_ds_mla:
        if not kv_c_contexts:
            raise ValueError(
                "kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
            )
        kv_lora_rank = kv_c_contexts[0].shape[-1]
        rope_dim = k_pe_contexts[0].shape[-1]
        # Per-token entry size in bytes for the fp8_ds_mla layout; presumably
        # fp8 latent + 4 float32 scale words + 2-byte rope values — layout is
        # defined by concat_and_cache_mla, TODO confirm.
        entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
        kv_cache = torch.zeros(
            num_blocks, block_size, entry_size, dtype=torch.uint8, device=device
        )
        scale_tensor = (
            scale
            if isinstance(scale, torch.Tensor)
            else torch.tensor(scale, dtype=torch.float32, device=device)
        )
        scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
    else:
        # Create MLA KV cache: (num_blocks, block_size, head_size)
        kv_cache = torch.zeros(
            num_blocks, block_size, head_size, dtype=dtype, device=device
        )
        kv_cache_flat = kv_cache.view(-1, head_size)

    # Populate the cache with the context tokens
    # Start from block_id=1 since block_id=0 is considered the null block
    start_block_idx = 1
    for i in range(batch_size):
        kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
        context_len = kv_c_context.shape[0]
        if context_len == 0:
            # No context to write, but still reserve this request's blocks.
            start_block_idx += cdiv(int(seq_lens[i]), block_size)
            continue

        start = start_block_idx * block_size

        if use_fp8_ds_mla:
            slots = torch.arange(context_len, device=device, dtype=torch.long) + start
            ops.concat_and_cache_mla(
                kv_c_context,
                k_pe_context.squeeze(1),
                kv_cache,
                slots,
                kv_cache_dtype="fp8_ds_mla",
                scale=scale_tensor,
            )
        else:
            # Latent and rope parts are stored concatenated along the last dim.
            kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
            end = start + kv_context.shape[0]
            kv_cache_flat[start:end, ...] = kv_context

        # Stay block aligned and allocate enough blocks for the new tokens
        start_block_idx += cdiv(int(seq_lens[i]), block_size)

    blocks_end = start_block_idx

    # Permute the context blocks (excluding block 0 which is null)
    if randomize_blocks:
        perm = (
            torch.randperm(blocks_end - 1) + 1
        )  # Random permutation starting from block 1
    else:
        perm = torch.arange(1, blocks_end)  # Sequential order starting from block 1

    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
    inv_perm[1:] = torch.argsort(perm) + 1  # Add 1 to account for starting from block 1
    kv_cache[1:blocks_end, ...] = kv_cache[perm, ...]

    # Construct the right block table
    # Start from block_id=1 since block_id=0 is considered the null block
    start_block_idx = 1
    for i in range(batch_size):
        num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size)
        start = start_block_idx
        end = start + num_blocks_for_seq
        block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
        # Unused tail of the row points at the null block.
        block_table[i, num_blocks_for_seq:] = 0
        start_block_idx += num_blocks_for_seq

    # Create a realistic slot mapping that corresponds to the block table
    for i in range(batch_size):
        # New (query) tokens sit right after the already-computed context.
        token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
        block_indices = token_offsets // block_size
        token_inter_block_offsets = token_offsets % block_size
        start = common_attn_metadata.query_start_loc_cpu[i]
        end = common_attn_metadata.query_start_loc_cpu[i + 1]
        slot_mapping[start:end] = block_table[
            i, block_indices
        ] * block_size + token_inter_block_offsets.to(device)

    return kv_cache
|
||||
|
||||
|
||||
class MockAttentionLayer:
    """Minimal stand-in for an attention layer.

    Exposes only the quantization-scale attributes that backend
    implementations read; all scales are identity (1.0) so they do not
    perturb the outputs under test.
    """

    def __init__(self, device: torch.device):
        # Tensor-valued scales, one independent tensor per attribute.
        for attr in ("_q_scale", "_k_scale", "_v_scale", "_prob_scale"):
            setattr(self, attr, torch.tensor(1.0, device=device))
        # Python-float mirrors of the q/k/v scales.
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

    def forward(self, *_args, **_kwargs):
        """Never invoked by these tests; the backend impl is called directly."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class MockMLAAttentionLayer(AttentionLayerBase):
    """A mock MLA attention layer for populating static_forward_context.

    Only carries the backend ``impl``; the AttentionLayerBase interface
    methods are left unimplemented because the tests call the impl
    directly rather than going through the layer.
    """

    def __init__(self, impl):
        # The backend AttentionImpl instance under test.
        self.impl = impl

    def get_attn_backend(self):
        # Not needed by these tests.
        raise NotImplementedError

    def get_kv_cache_spec(self, vllm_config):
        # Not needed by these tests.
        raise NotImplementedError
|
||||
|
||||
|
||||
def run_attention_backend(
    backend: AttentionBackendEnum,
    kv_cache_spec: FullAttentionSpec,
    layer_names: list[str],
    vllm_config,
    device: torch.device,
    common_attn_metadata: CommonAttentionMetadata,
    query: torch.Tensor,
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    mock_kv_b_proj,
) -> torch.Tensor:
    """Run attention computation using the specified backend's AttentionImpl.

    Instantiates the backend's MLA implementation, builds its attention
    metadata from ``common_attn_metadata``, and executes a single forward
    pass over the new-token tensors against the prepopulated paged KV cache.

    Args:
        backend: Which attention backend to exercise.
        kv_cache_spec: KV-cache layout spec handed to the metadata builder.
        layer_names: Layer names registered into static_forward_context.
        vllm_config: Full vLLM config; also installed as the "current" config.
        device: Device the metadata builder targets.
        query / kv_c / k_pe: New-token MLA inputs (already correctly shaped
            by the caller).
        kv_cache: Prepopulated paged KV cache for this backend's block size.
        kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim:
            MLA head-dimension parameters forwarded to the impl.
        mock_kv_b_proj: kv_b_proj layer sharing weights with the reference.

    Returns:
        The backend's output tensor of shape [num_tokens, num_heads * v_head_dim].
    """

    builder_cls, impl_cls = try_get_attention_backend(backend)

    # Set the current vllm config so that get_current_vllm_config() works
    # in the backend implementations
    with set_current_vllm_config(vllm_config):
        # Instantiate MLA implementation
        num_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config
        )
        head_size = vllm_config.model_config.get_head_size()
        scale = 1.0 / (head_size**0.5)
        impl = impl_cls(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="auto",
            logits_soft_cap=None,
            attn_type="decoder",
            kv_sharing_target_layer_name=None,
            q_lora_rank=None,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
            v_head_dim=v_head_dim,
            kv_b_proj=mock_kv_b_proj,
        )

        # Process weights to create W_UK_T and W_UV attributes needed by MLA
        act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
        impl.process_weights_after_loading(act_dtype)

        # Populate static_forward_context with mock attention layers
        # (the metadata builder may look layers up by name).
        for layer_name in layer_names:
            vllm_config.compilation_config.static_forward_context[layer_name] = (
                MockMLAAttentionLayer(impl)
            )

        # Build metadata
        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
        attn_metadata = builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

        # Create mock layer and output buffer
        # (output is preallocated as [num_tokens, num_heads * v_head_dim]).
        mock_layer = MockAttentionLayer(device)
        num_tokens = query.shape[0]
        output = torch.empty(
            num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
        )

        # Run forward pass
        # NOTE: The query, key, and value are already shaped correctly
        # in the calling test function.
        output = impl.forward(
            mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
        )

        return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_small",
        "medium_decode",
        "medium_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "single_decode",
        "single_prefill",
        "spec_decode_small",
        "spec_decode_medium",
    ],
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
def test_backend_correctness(
    dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """
    Test that all backends produce similar outputs to a reference implementation
    using torch.nn.functional.scaled_dot_product_attention.

    This test works by:
    1. Generating a batch of sequences with specified context and query lengths.
    2. Computing a ground-truth attention output using torch.sdpa on
       contiguous Q, K, and V tensors.
    3. Simulating vLLM's paged KV cache: It takes the context portion of the
       K/V tensors and manually places them into a paged buffer according to
       the test's (randomly generated) block table.
    4. Running each vLLM attention backend with the new queries and the
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.

    Note: When tensor_parallel_size > 1, we simulate the head partitioning
    by overriding the model config to use fewer heads, without requiring
    multiple GPUs. This tests that backends work correctly with different
    head counts.
    """

    batch_spec = BATCH_SPECS[batch_spec_name]
    is_spec_decode_test = batch_spec_name.startswith("spec_decode")
    # Block sizes differ per backend; sizes are tested from smallest up.
    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
    default_block_size = unique_block_sizes[0]
    required_blocks = sum(
        (seq_len + default_block_size - 1) // default_block_size
        for seq_len in batch_spec.seq_lens
    )
    # Add 1 for null block at index 0, and some buffer
    num_gpu_blocks = required_blocks + 1 + 100

    hf_config_override = None
    if tensor_parallel_size > 1:
        from vllm.config import ModelConfig

        # Shrink head counts to emulate the per-rank share under TP,
        # while actually running with tensor_parallel_size=1 below.
        temp_config = ModelConfig(model=model, max_model_len=1)
        original_num_heads = temp_config.hf_text_config.num_attention_heads
        original_num_kv_heads = getattr(
            temp_config.hf_text_config, "num_key_value_heads", None
        )
        hf_config_override = {
            "num_attention_heads": original_num_heads // tensor_parallel_size,
        }
        if original_num_kv_heads is not None:
            hf_config_override["num_key_value_heads"] = max(
                1, original_num_kv_heads // tensor_parallel_size
            )

    vllm_config = create_vllm_config(
        model_name=model,
        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
        max_model_len=max(batch_spec.seq_lens),
        num_gpu_blocks=num_gpu_blocks,
        block_size=default_block_size,
        hf_config_override=hf_config_override,
    )

    # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
    if is_spec_decode_test:
        from vllm.config import SpeculativeConfig

        # Get the query length from the batch spec (they should all be uniform)
        query_len = batch_spec.query_lens[0]
        # Set num_speculative_tokens to query_len - 1
        # (since threshold is 1 + num_spec_tokens)
        # Use ngram method which doesn't require a draft model
        vllm_config.speculative_config = SpeculativeConfig(
            method="ngram", num_speculative_tokens=query_len - 1
        )

    device = torch.device("cuda:0")

    # 1. Setup
    batch_size = batch_spec.batch_size
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens
    num_q_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config
    )
    head_size = vllm_config.model_config.get_head_size()
    dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
    # MLA dimensions (fixed to match the model's head_size below).
    kv_lora_rank = 512
    qk_rope_head_dim = 64
    qk_nope_head_dim = 128
    v_head_dim = 128
    total_head_size = kv_lora_rank + qk_rope_head_dim
    assert kv_lora_rank + qk_rope_head_dim == head_size, (
        f"MLA dimensions don't match: {total_head_size} != {head_size}"
    )
    scale = 1.0 / (total_head_size**0.5)

    # 2. Generate data and compute SDPA reference output for MLA
    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
    # One reference-output list per backend: backends on the decode path
    # and the prefill path get different (but equivalent) references.
    all_sdpa_outputs: list[list[torch.Tensor]] = []
    kv_c_contexts, k_pe_contexts = [], []

    # Create shared MLA weight matrices for consistency across all sequences
    W_UK = torch.randn(
        kv_lora_rank, num_q_heads, qk_nope_head_dim, dtype=dtype, device=device
    )
    W_UV = torch.randn(
        kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
    )
    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)

    for i, backend in enumerate(BACKENDS_TO_TEST):
        all_sdpa_outputs.append([])

    for i in range(batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        context_len = s_len - q_len

        # Generate MLA tensors
        # Q has both nope and rope components:
        # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim]
        q_c = torch.randn(
            q_len,
            num_q_heads,
            qk_nope_head_dim + qk_rope_head_dim,
            dtype=dtype,
            device=device,
        )

        # KV_C (latent K/V): [s_len, kv_lora_rank]
        kv_c_full = torch.randn(s_len, kv_lora_rank, dtype=dtype, device=device)

        # K_PE (rope component): [s_len, 1, qk_rope_head_dim]
        k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        # Determine if this sequence uses the decode pipeline or prefill
        # pipeline for each backend
        # NOTE: For spec decode tests with uniform query_len > 1, backends that
        # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
        # uniform support) will use the decode pipeline (MQA-style), while
        # backends that only support single-token queries will use the prefill
        # pipeline (MHA-style). This ensures the reference implementation
        # matches each backend's actual decode/prefill pipeline path.
        is_decode = []
        for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
            builder_cls, _ = try_get_attention_backend(backend)
            if is_spec_decode_test:
                query_len_support = getattr(
                    builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
                )
                supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
                is_decode.append(supports_spec)
            else:
                threshold = getattr(builder_cls, "reorder_batch_threshold", None)
                query_len_support = getattr(
                    builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
                )
                within_threshold = q_len <= threshold if threshold else False
                # UNIFORM support additionally requires every request's
                # query length to match the first request's.
                if (
                    within_threshold
                    and query_len_support == QueryLenSupport.UNIFORM
                    and i > 0
                ):
                    first_q_len = query_lens[0]
                    within_threshold = q_len == first_q_len
                is_decode.append(within_threshold)

        # Split q into nope and rope components
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)

        #######################################################
        # Decode path: MQA-style attention in latent space
        # Transform q_nope to latent space: q_nope @ W_UK
        # q_nope: [1, num_heads, qk_nope_head_dim]
        # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
        ql_nope = torch.einsum(
            "qnh,lnh->qnl", q_nope, W_UK
        )  # [1, num_heads, kv_lora_rank]

        # Build MQA attention inputs
        # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
        # K: [s_len, kv_lora_rank + qk_rope_head_dim]
        # (broadcasted to all heads)
        k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
        # V: [s_len, kv_lora_rank] (broadcasted to all heads)
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)

        # Create custom attention mask for decode path:
        # - Query tokens can attend to all context tokens
        # - Query tokens can only attend to query tokens up to their position
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        # Apply causal mask only to the query portion (context_len onwards)
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
        attn_mask[:, context_len:] = causal_mask

        # SDPA expects (N, H, L, D)
        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

        sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
            0
        )  # [1, num_heads, kv_lora_rank]

        # Project back to output space: sdpa_out @ W_UV
        sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, W_UV)
        sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)

        #######################################################
        # Prefill path: MHA-style attention with full sequence
        # Apply kv_b_proj to the full kv_c tensor
        kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
        k_nope_full, v_full = kv_nope_full.split([qk_nope_head_dim, v_head_dim], dim=-1)

        # Build attention inputs for full sequence
        q_mha = torch.cat([q_nope, q_pe], dim=-1)  # [q_len, num_heads, total_dim]
        k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
        k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)

        # Create custom attention mask:
        # - Query tokens can attend to all context tokens
        # - Query tokens can only attend to query tokens up to their pos
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        # Apply causal mask only to the query portion (context_len onwards)
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
        attn_mask[:, context_len:] = causal_mask

        # SDPA expects (N, H, L, D)
        q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

        # Single attention call with custom mask
        sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
        sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

        # Pick the reference matching each backend's pipeline for this seq.
        for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
            if is_decode[backend_idx]:
                all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
            else:
                all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)

        # Inputs for vLLM MLA backends are just the new tokens
        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[context_len:])  # New kv_c tokens
        all_k_pe_vllm.append(k_pe_full[context_len:])  # New k_pe tokens

        # Contextual K/V data used to populate the paged cache (MLA format)
        kv_c_contexts.append(kv_c_full[:context_len])
        k_pe_contexts.append(k_pe_full[:context_len])

    # Concatenate all sequences (no reordering needed)
    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
    # Reference outputs keyed by backend (enum member or string entry).
    sdpa_outputs = {}
    for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
        sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)

    # Create mock kv_b_proj using the same weights as reference implementation
    from vllm.model_executor.layers.linear import ColumnParallelLinear

    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_q_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)

    # Set the mock weights to match our reference implementation
    # Reshape W_UK and W_UV to match the expected kv_b_proj format
    # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim]
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
    )
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

    # 3. Create metadata and KV caches for each block size
    # Group backends by block size and test each group
    metadata_per_block_size = {}
    kv_cache_per_block_size = {}

    for block_size in unique_block_sizes:
        # Create metadata for this block size
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, block_size, device
        )

        # Pad block table to meet requirement:
        # block_num % (128 / block_size) == 0
        required_divisor = int(128 / block_size)
        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
        if current_block_num % required_divisor != 0:
            # Pad to next multiple of required_divisor
            padded_block_num = (
                (current_block_num + required_divisor - 1) // required_divisor
            ) * required_divisor
            padding_cols = padded_block_num - current_block_num
            padding = torch.zeros(
                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
                dtype=torch.int32,
                device=device,
            )
            common_attn_metadata.block_table_tensor = torch.cat(
                [common_attn_metadata.block_table_tensor, padding], dim=1
            )

        metadata_per_block_size[block_size] = common_attn_metadata

        # Create KV cache for this block size
        required_blocks_for_size = sum(
            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
        )
        # +1 for the null block at index 0, plus slack.
        num_blocks_for_size = required_blocks_for_size + 1 + 100

        kv_cache = create_and_prepopulate_kv_cache(
            kv_c_contexts=kv_c_contexts,
            k_pe_contexts=k_pe_contexts,
            block_size=block_size,
            head_size=head_size,
            dtype=dtype,
            device=device,
            num_blocks=num_blocks_for_size,
            common_attn_metadata=common_attn_metadata,
            randomize_blocks=True,
        )
        kv_cache_per_block_size[block_size] = kv_cache

    # 4. Run vLLM backends and compare
    failures = []
    for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
        # Skip backends that don't support spec decode for spec decode tests
        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
            continue

        # Get the appropriate block_size, metadata, and cache for this backend
        block_size = BACKEND_BLOCK_SIZES[backend_name]
        common_attn_metadata = metadata_per_block_size[block_size]
        kv_cache = kv_cache_per_block_size[block_size]

        # Create kv_cache_spec with the correct block_size for this backend
        backend_kv_cache_spec = FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                vllm_config.parallel_config
            ),
            head_size=vllm_config.model_config.get_head_size(),
            dtype=vllm_config.model_config.dtype,
            sliding_window=vllm_config.model_config.get_sliding_window(),
        )

        backend_output = run_attention_backend(
            backend_name,
            backend_kv_cache_spec,
            ["placeholder"],
            vllm_config,
            device,
            common_attn_metadata,
            query_vllm,
            kv_c_vllm,
            k_pe_vllm,
            kv_cache,
            kv_lora_rank,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            mock_kv_b_proj,
        )

        # Look up the reference SDPA output recorded for this backend
        # (sdpa_outputs is keyed by the backend itself).
        expected_output = sdpa_outputs[backend_name]

        # Check shape and dtype consistency; collect failures so that all
        # backends are exercised before the test reports.
        try:
            assert backend_output.shape == expected_output.shape, (
                f"[{backend_name}] shape {backend_output.shape} != "
                f"SDPA shape {expected_output.shape}"
            )
            assert backend_output.dtype == expected_output.dtype, (
                f"[{backend_name}] dtype {backend_output.dtype} != "
                f"SDPA dtype {expected_output.dtype}"
            )

            assert torch.isfinite(backend_output).all(), (
                f"[{backend_name}] produced non-finite values"
            )

            # Check numerical similarity
            rtol = 1e-2
            atol = 5e-1

            max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
            # NOTE(review): this divides by |expected|, which can be inf/NaN if
            # expected_output has exact zeros; it is only used for reporting.
            max_rel_diff = torch.max(
                torch.abs(backend_output - expected_output) / torch.abs(expected_output)
            ).item()
            all_close = torch.allclose(
                backend_output, expected_output, rtol=rtol, atol=atol
            )

            assert all_close, (
                f"[{backend_name}] output differs from SDPA baseline. "
                f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
            )
        except AssertionError as e:
            failures.append(str(e))

    # Report all failures at once
    if failures:
        # Create a summary for the single-line failure message
        backend_names = []
        for f in failures:
            if "[AttentionBackendEnum." in f:
                backend_name = f.split("[")[1].split("]")[0]
                backend_names.append(backend_name)

        summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
        detailed_msg = "\n".join(failures)
        pytest.fail(f"{summary}\n{detailed_msg}")
|
||||
340
tests/v1/attention/test_rocm_attention_backends_selection.py
Normal file
340
tests/v1/attention/test_rocm_attention_backends_selection.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for attention backend selectors."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# ROCm-specific attention backend selection tests
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="ROCm-specific tests"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vllm_config():
    """Build a MagicMock standing in for VllmConfig, carrying the
    attributes that backend-selection code inspects."""
    cfg = MagicMock()
    cfg.model_config.dtype = torch.float16
    cfg.model_config.hf_config.architectures = ["LlamaForCausalLM"]
    cfg.cache_config.block_size = 16
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture
def mock_on_gfx9():
    """Force vllm.platforms.rocm.on_gfx9 to report True for the test's duration."""
    patcher = patch("vllm.platforms.rocm.on_gfx9", return_value=True)
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "env_vars, selected_backend, expected_backend_path",
    [
        # Test Case: Explicit FLEX_ATTENTION backend
        (
            {},
            "FLEX_ATTENTION",
            AttentionBackendEnum.FLEX_ATTENTION.get_path(),
        ),
        # Test Case 1: Default (no env vars, no explicit backend)
        (
            {},
            None,
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 2: Explicit TRITON_ATTN backend
        (
            {},
            "TRITON_ATTN",
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 3: Explicit ROCM_ATTN backend
        (
            {},
            "ROCM_ATTN",
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
        # Test Case 4: Explicit ROCM_AITER_FA backend
        (
            {},
            "ROCM_AITER_FA",
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Test Case 5: Explicit ROCM_AITER_UNIFIED_ATTN backend
        (
            {},
            "ROCM_AITER_UNIFIED_ATTN",
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
        # Test Case 6: VLLM_ROCM_USE_AITER=1
        # (defaults to AITER FA when MHA not explicitly disabled)
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
            None,
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
        (
            {
                "VLLM_ROCM_USE_AITER": "1",
                "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
            },
            None,
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
        # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
        (
            {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
            None,
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
        # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
        # (an explicit selection wins over env-var defaults)
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_ATTN",
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
        # (explicitly disabled)
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
            None,
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "ROCM_ATTN",
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
    ],
)
def test_standard_attention_backend_selection(
    env_vars,
    selected_backend,
    expected_backend_path,
    mock_vllm_config,
    mock_on_gfx9,
    monkeypatch,
):
    """Test standard (non-MLA) attention backend selection.

    Sets the given env vars, reloads vllm.envs so they take effect, then
    asks RocmPlatform for the backend class path and compares it against
    the expectation.
    """
    # Set environment variables
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    # Import after setting env vars to ensure they're picked up
    # Reload envs to pick up new environment variables
    import importlib

    import vllm.envs as envs

    importlib.reload(envs)

    # Convert string backend to enum if provided
    backend_enum = None
    if selected_backend:
        backend_enum = getattr(AttentionBackendEnum, selected_backend)

    # Get the backend class path
    from vllm.platforms.rocm import RocmPlatform

    backend_path = RocmPlatform.get_attn_backend_cls(
        selected_backend=backend_enum,
        head_size=128,
        dtype=torch.float16,
        kv_cache_dtype="auto",
        block_size=16,
        use_mla=False,
        has_sink=False,
        use_sparse=False,
    )
    assert backend_path == expected_backend_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
    [
        # Test Case 1: TRITON_MLA with block_size != 1
        (
            {},
            "TRITON_MLA",
            16,
            AttentionBackendEnum.TRITON_MLA.get_path(),
            False,
        ),
        # Test Case 2: TRITON_MLA with block_size == 1 (should raise)
        (
            {},
            "TRITON_MLA",
            1,
            None,
            True,
        ),
        # Test Case 3: ROCM_AITER_MLA with block_size == 1
        (
            {},
            "ROCM_AITER_MLA",
            1,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 4: ROCM_AITER_MLA with block_size != 1
        # (accepted: ROCM_AITER_MLA supports block_size 16 — see Test Case 6)
        (
            {},
            "ROCM_AITER_MLA",
            16,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 5: VLLM_ROCM_USE_AITER=1 with block_size == 1
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            1,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 6: VLLM_ROCM_USE_AITER=1 with block_size == 16
        # (should use ROCM_AITER_MLA now, as it supports block_size 16)
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            16,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_MLA
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_MLA",
            16,
            AttentionBackendEnum.TRITON_MLA.get_path(),
            False,
        ),
        # Test Case 8: Explicit ROCM_AITER_TRITON_MLA
        (
            {},
            "ROCM_AITER_TRITON_MLA",
            16,
            AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
            False,
        ),
    ],
)
def test_mla_backend_selection(
    env_vars,
    selected_backend,
    block_size,
    expected_backend_path,
    should_raise,
    mock_vllm_config,
    monkeypatch,
):
    """Test MLA backend selection with various configurations.

    Mocks vllm._aiter_ops so that rocm_aiter_ops.is_mla_enabled() reflects
    whether VLLM_ROCM_USE_AITER was set, then verifies either the selected
    backend path or that a ValueError is raised.
    """
    # Set environment variables
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    # Import after setting env vars
    # Reload envs
    import importlib

    import vllm.envs as envs

    importlib.reload(envs)

    # Mock is_aiter_mla_enabled based on env vars and block_size
    aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"

    mock_rocm_ops = MagicMock()
    mock_rocm_ops.is_mla_enabled.return_value = aiter_enabled
    mock_aiter_module = MagicMock()
    mock_aiter_module.rocm_aiter_ops = mock_rocm_ops

    # Install the mocked module so platform code importing vllm._aiter_ops
    # sees our stub.
    with patch.dict("sys.modules", {"vllm._aiter_ops": mock_aiter_module}):
        # Convert string backend to enum if provided
        backend_enum = None
        if selected_backend:
            backend_enum = getattr(AttentionBackendEnum, selected_backend)

        from vllm.platforms.rocm import RocmPlatform

        if should_raise:
            with pytest.raises(ValueError):
                RocmPlatform.get_attn_backend_cls(
                    selected_backend=backend_enum,
                    head_size=128,
                    dtype=torch.float16,
                    kv_cache_dtype="auto",
                    block_size=block_size,
                    use_mla=True,
                    has_sink=False,
                    use_sparse=False,
                )
        else:
            backend_path = RocmPlatform.get_attn_backend_cls(
                selected_backend=backend_enum,
                head_size=128,
                dtype=torch.float16,
                kv_cache_dtype="auto",
                block_size=block_size,
                use_mla=True,
                has_sink=False,
                use_sparse=False,
            )
            assert backend_path == expected_backend_path
|
||||
|
||||
|
||||
def test_aiter_fa_requires_gfx9(mock_vllm_config):
    """Selecting ROCM_AITER_FA on a non-gfx9 GPU must raise ValueError."""
    from vllm.platforms.rocm import RocmPlatform

    # Pretend the GPU is not gfx9; the backend selection should then fail.
    with patch("vllm.platforms.rocm.on_gfx9", return_value=False):
        with pytest.raises(ValueError, match="only supported on gfx9"):
            RocmPlatform.get_attn_backend_cls(
                selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
                head_size=128,
                dtype=torch.float16,
                kv_cache_dtype="auto",
                block_size=16,
                use_mla=False,
                has_sink=False,
                use_sparse=False,
            )
|
||||
|
||||
|
||||
def test_sparse_not_supported(mock_vllm_config):
    """Sparse MLA on ROCm rejects an unsupported block size.

    NOTE(review): despite the test name, the assertion checks the
    block-size constraint ("only supports block size 1") rather than a
    blanket lack of sparse support — block_size=16 triggers the failure.
    """
    from vllm.platforms.rocm import RocmPlatform

    with pytest.raises(
        AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
    ):
        RocmPlatform.get_attn_backend_cls(
            selected_backend=None,
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype="auto",
            block_size=16,
            use_mla=False,
            has_sink=False,
            use_sparse=True,
        )
|
||||
566
tests/v1/attention/test_sparse_mla_backends.py
Normal file
566
tests/v1/attention/test_sparse_mla_backends.py
Normal file
@@ -0,0 +1,566 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for the FlashMLA sparse backend utilities."""
|
||||
|
||||
import math
|
||||
from types import MethodType, SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_mla_backends import (
|
||||
BATCH_SPECS,
|
||||
BatchSpec,
|
||||
MockAttentionLayer,
|
||||
create_and_prepopulate_kv_cache,
|
||||
)
|
||||
from tests.v1.attention.utils import (
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
|
||||
# Batch shapes reused from the generic MLA backend tests, restricted to the
# specs that are meaningful for the sparse backend.
SPARSE_BACKEND_BATCH_SPECS = {
    name: BATCH_SPECS[name]
    for name in [
        "mixed_small",
        "mixed_medium",
        "small_prefill",
        "medium_prefill",
        "single_prefill",
    ]
}

# Extra large-query prefill shapes: one with extra context (seq > query) and
# one "pure" prefill where the whole sequence is the query.
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
    seq_lens=[1024] * 2, query_lens=[256] * 2
)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
    seq_lens=[256] * 2, query_lens=[256] * 2
)
|
||||
|
||||
|
||||
def _dequantize_fp8_ds_mla_entry(
    cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize a single fp8_ds_mla cache entry back to latent + rope.

    Args:
        cache_slice: uint8 view of one cache entry (one token's payload).
        kv_lora_rank: width of the latent vector stored as FP8 bytes.
        rope_dim: number of rope elements stored after the latent+scales.
        dtype: target dtype for the dequantized latent.

    Returns:
        (latent, rope) tensors for the entry.
    """

    # The first kv_lora_rank bytes store FP8 latent values with one scale per
    # 128 element tile written as float32 right after the latent payload.
    # NOTE(review): the fixed 4-tile loop below assumes kv_lora_rank == 512
    # (4 tiles of 128) — confirm if this helper is reused with other ranks.
    scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4]
    latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
    for tile_idx in range(4):
        tile_start = tile_idx * 128
        tile_end = tile_start + 128
        # Dequantize one 128-wide tile with its per-tile scale.
        ops.convert_fp8(
            latent[tile_start:tile_end],
            cache_slice[tile_start:tile_end],
            float(scales[tile_idx].item()),
            kv_dtype="fp8",
        )
    latent = latent.to(dtype)

    # Rope values live after latent bytes + 4 float32 scales (16 bytes =
    # 8 elements of the 2-byte dtype), hence the offset in dtype units.
    rope_offset = kv_lora_rank // 2 + 8
    rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim]
    return latent, rope_vals.clone()
|
||||
|
||||
|
||||
def _quantize_dequantize_fp8_ds_mla(
    kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout.

    Writes the tensors into a temporary paged cache with
    ops.concat_and_cache_mla (fp8_ds_mla layout) and reads them back with
    _dequantize_fp8_ds_mla_entry, so tests compare against values that
    carry the same quantization error as the real cache path.
    """

    if kv_c.numel() == 0:
        return kv_c.clone(), k_pe.clone()

    kv_lora_rank = kv_c.shape[-1]
    rope_dim = k_pe.shape[-1]
    num_tokens = kv_c.shape[0]
    num_blocks = max(1, math.ceil(num_tokens / block_size))
    # Entry layout: FP8 latent bytes + 4 float32 scales + rope in 2-byte dtype.
    entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim

    tmp_cache = torch.zeros(
        num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
    )
    # Identity slot mapping: token i is stored at global slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)

    ops.concat_and_cache_mla(
        kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
    )

    dequant_kv_c = torch.empty_like(kv_c)
    dequant_k_pe = torch.empty_like(k_pe)

    # Read every token back out of the paged layout and dequantize it.
    for token_idx in range(num_tokens):
        slot = slot_mapping[token_idx].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        cache_slice = tmp_cache[block_idx, block_offset]
        latent, rope_vals = _dequantize_fp8_ds_mla_entry(
            cache_slice, kv_lora_rank, rope_dim, kv_c.dtype
        )
        dequant_kv_c[token_idx] = latent
        dequant_k_pe[token_idx] = rope_vals

    return dequant_kv_c, dequant_k_pe
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
|
||||
|
||||
# Model hyper-parameters (kept intentionally small for the unit test)
|
||||
num_heads = 128
|
||||
kv_lora_rank = 512
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
v_head_dim = 128
|
||||
head_size = kv_lora_rank + qk_rope_head_dim
|
||||
topk_tokens = 2048
|
||||
|
||||
max_seqlen = max(batch_spec.seq_lens)
|
||||
total_cache_tokens = sum(batch_spec.seq_lens)
|
||||
block_size = 64
|
||||
|
||||
# Note: We use TP=1 to avoid multi-GPU requirements in CI.
|
||||
# The test simulates head partitioning via mocked methods below.
|
||||
vllm_config = create_vllm_config(
|
||||
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=max_seqlen,
|
||||
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size,
|
||||
hf_config_override={
|
||||
"index_topk": topk_tokens,
|
||||
"attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
|
||||
},
|
||||
)
|
||||
model_config = vllm_config.model_config
|
||||
model_config.hf_text_config = SimpleNamespace(
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
model_type="deepseek_v2",
|
||||
)
|
||||
model_config.dtype = dtype
|
||||
model_config.get_num_attention_heads = MethodType(
|
||||
lambda self, parallel_config: max(1, num_heads // tensor_parallel_size),
|
||||
model_config,
|
||||
)
|
||||
model_config.get_num_kv_heads = MethodType(
|
||||
lambda self, parallel_config: 1, model_config
|
||||
)
|
||||
model_config.get_head_size = MethodType(lambda self: head_size, model_config)
|
||||
model_config.get_sliding_window = MethodType(lambda self: None, model_config)
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = 1.0 / math.sqrt(head_size)
|
||||
|
||||
# Shared MLA projection weights to keep reference and backend in sync
|
||||
W_UK = torch.randn(
|
||||
kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Build synthetic decode-only workload
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
reference_outputs = []
|
||||
|
||||
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(batch_spec.batch_size):
|
||||
s_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
ctx_len = s_len - q_len
|
||||
|
||||
q_c = torch.rand(
|
||||
q_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
|
||||
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
k_pe_full.squeeze(1),
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, ctx_len:] = causal_mask
|
||||
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
|
||||
)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
|
||||
|
||||
all_q_vllm.append(q_c)
|
||||
all_kv_c_vllm.append(kv_c_full[ctx_len:])
|
||||
all_k_pe_vllm.append(k_pe_full[ctx_len:])
|
||||
kv_c_contexts.append(kv_c_full[: ctx_len + 1])
|
||||
k_pe_contexts.append(k_pe_full[: ctx_len + 1])
|
||||
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
|
||||
sdpa_reference = torch.cat(reference_outputs, dim=0)
|
||||
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
vllm_config.model_config.hf_config.index_topk = topk_tokens
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
vllm_config.cache_config.block_size,
|
||||
device,
|
||||
arange_block_indices=True,
|
||||
)
|
||||
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
k_pe_contexts=k_pe_contexts,
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=False,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
builder_cls = FlashMLASparseBackend.get_builder_cls()
|
||||
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
|
||||
metadata = builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
|
||||
starts[:-1], seg_lengths
|
||||
)
|
||||
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
|
||||
prefix_lengths = seq_lengths - seg_lengths
|
||||
positions += np.repeat(prefix_lengths, seg_lengths)
|
||||
|
||||
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
|
||||
topk = metadata.topk_tokens
|
||||
debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
|
||||
token_positions = pos_gpu.unsqueeze(1)
|
||||
causal_mask = debug_indices <= token_positions
|
||||
debug_indices = torch.where(
|
||||
causal_mask, debug_indices, torch.full_like(debug_indices, -1)
|
||||
)
|
||||
|
||||
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
|
||||
# buffer, so emulate that contract with a simple namespace mock.
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
|
||||
|
||||
ok, reason = flashmla.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
|
||||
)
|
||||
|
||||
mock_kv_b_proj = ColumnParallelLinear(
|
||||
input_size=kv_lora_rank,
|
||||
output_size=num_heads * (qk_nope_head_dim + v_head_dim),
|
||||
bias=False,
|
||||
).to(device=device, dtype=dtype)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(
|
||||
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
backend_output = impl.forward(
|
||||
layer,
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
metadata,
|
||||
output=out_buffer,
|
||||
)
|
||||
|
||||
assert backend_output.shape == sdpa_reference.shape
|
||||
assert backend_output.dtype == sdpa_reference.dtype
|
||||
assert torch.isfinite(backend_output).all()
|
||||
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
|
||||
|
||||
|
||||
def _triton_convert_reference_impl(
|
||||
req_ids: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
token_indices: torch.Tensor,
|
||||
block_size: int,
|
||||
num_topk_tokens: int,
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Reference implementation for triton_convert_req_index_to_global_index."""
|
||||
num_tokens = req_ids.shape[0]
|
||||
max_blocks_per_req = block_table.shape[1]
|
||||
result = torch.empty(
|
||||
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
|
||||
)
|
||||
|
||||
for token_id in range(num_tokens):
|
||||
req_id = req_ids[token_id].item()
|
||||
|
||||
# Determine if this token uses workspace or paged cache
|
||||
use_prefill_workspace = False
|
||||
workspace_start = 0
|
||||
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
|
||||
assert prefill_workspace_starts is not None
|
||||
prefill_req_id = prefill_workspace_request_ids[token_id].item()
|
||||
if prefill_req_id >= 0:
|
||||
use_prefill_workspace = True
|
||||
workspace_start = prefill_workspace_starts[prefill_req_id].item()
|
||||
|
||||
for idx_id in range(num_topk_tokens):
|
||||
token_idx = token_indices[token_id, idx_id].item()
|
||||
|
||||
if token_idx == -1:
|
||||
result[token_id, idx_id] = -1
|
||||
elif use_prefill_workspace:
|
||||
# Prefill + using prefill workspace: map to workspace offset
|
||||
result[token_id, idx_id] = workspace_start + token_idx
|
||||
else:
|
||||
# Decode: map to paged cache
|
||||
block_id = token_idx // block_size
|
||||
if block_id >= max_blocks_per_req:
|
||||
result[token_id, idx_id] = -1
|
||||
else:
|
||||
block_num = block_table[req_id, block_id].item()
|
||||
offset = token_idx % block_size
|
||||
result[token_id, idx_id] = block_num * block_size + offset
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16, 64, 128])
|
||||
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_decode_only(
|
||||
block_size, num_topk_tokens
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
num_tokens = 8
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 10
|
||||
|
||||
req_id = torch.randint(
|
||||
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
|
||||
)
|
||||
block_table = torch.randint(
|
||||
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(num_tokens, num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
|
||||
device = torch.device("cuda")
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 8
|
||||
num_topk_tokens = 128
|
||||
|
||||
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
|
||||
req_id = torch.tensor(
|
||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
|
||||
)
|
||||
prefill_workspace_request_ids = torch.tensor(
|
||||
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
|
||||
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
|
||||
|
||||
block_table = torch.randint(
|
||||
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(req_id.shape[0], num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "seq_lens,max_buf,expected",
    [
        # Basic split: totals per chunk ≤ max_buf
        (torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
        # Exact fits should split between items when adding the next would overflow
        (torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
        # All requests fit in a single chunk
        (torch.tensor([1, 1, 1]), 10, [(0, 3)]),
        # Large buffer
        (torch.tensor([4, 4, 4]), 100, [(0, 3)]),
    ],
)
def test_split_prefill_chunks(seq_lens, max_buf, expected):
    """split_prefill_chunks greedily packs requests into buffer-sized chunks.

    Expected output is a list of (start, end) request-index ranges whose
    summed seq_lens each fit within max_buf.
    """
    out = split_prefill_chunks(seq_lens, max_buf)
    assert out == expected
|
||||
352
tests/v1/attention/utils.py
Normal file
352
tests/v1/attention/utils.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for attention-related v1 tests."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionImpl
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
DeviceConfig,
|
||||
LoadConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.model import ModelDType
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
@dataclass
class BatchSpec:
    """Workload-shape description of a batch: per-request sequence and
    query lengths, plus an optional human-readable name."""

    seq_lens: list[int]
    query_lens: list[int]

    name: str = "unnamed"

    def __post_init__(self):
        # Each request needs both a sequence length and a query length.
        assert len(self.seq_lens) == len(self.query_lens)

    @property
    def batch_size(self):
        """Number of requests in the batch."""
        return len(self.seq_lens)

    def compute_num_tokens(self):
        """Total number of query tokens across all requests."""
        return sum(self.query_lens)
|
||||
|
||||
|
||||
def create_common_attn_metadata(
    batch_spec: BatchSpec,
    block_size: int,
    device: torch.device,
    max_block_idx: int = 1000,
    arange_block_indices: bool = False,
) -> CommonAttentionMetadata:
    """Create CommonAttentionMetadata from a BatchSpec and ModelParams.

    Args:
        batch_spec: workload shape (per-request seq/query lengths).
        block_size: KV-cache block size in tokens.
        device: device for the GPU-side tensors.
        max_block_idx: exclusive upper bound for random block ids.
        arange_block_indices: if True, use deterministic sequential block
            indices and slot mapping instead of random ones.
    """
    # Create query start locations
    query_start_loc = torch.zeros(
        batch_spec.batch_size + 1, dtype=torch.int32, device=device
    )
    query_start_loc[1:] = torch.tensor(
        batch_spec.query_lens, dtype=torch.int32, device=device
    ).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()
    num_tokens = batch_spec.compute_num_tokens()

    # Create sequence lengths
    seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Create computed tokens (context length for each sequence)
    context_lens = [
        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for i in range(batch_spec.batch_size)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Create block table and slot mapping
    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
    if arange_block_indices:
        # Deterministic layout: request i owns blocks
        # [i*max_blocks, (i+1)*max_blocks) and tokens map to slots 0..N-1.
        num_blocks = batch_spec.batch_size * max_blocks
        block_table_tensor = torch.arange(
            num_blocks, dtype=torch.int32, device=device
        ).view(batch_spec.batch_size, max_blocks)
        slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device).view(
            num_tokens
        )
    else:
        # Random layout: block ids and slots are arbitrary but in-range.
        block_table_tensor = torch.randint(
            0,
            max_block_idx,
            (batch_spec.batch_size, max_blocks),
            dtype=torch.int32,
            device=device,
        )
        slot_mapping = torch.randint(
            0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device
        )

    # Calculate max query length
    max_query_len = max(batch_spec.query_lens)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_spec.batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )
|
||||
|
||||
|
||||
def try_get_attention_backend(
    backend: AttentionBackendEnum,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
    """Try to get the attention backend class, skipping test if not found.

    Returns:
        (builder class, impl class) for the backend.
    """
    try:
        backend_class = backend.get_class()
        return backend_class.get_builder_cls(), backend_class.get_impl_cls()
    except ImportError as e:
        # pytest.skip raises, so control never reaches the line below; the
        # unreachable raise documents that fact for type checkers.
        pytest.skip(f"{backend.name} not available: {e}")
        raise AssertionError("unreachable") from None
|
||||
|
||||
|
||||
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
    """Create a FullAttentionSpec from ModelParams only.

    All fields are derived from the model/cache/parallel configs inside the
    given VllmConfig, so tests only need to build that one object.
    """
    return FullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config
        ),
        head_size=vllm_config.model_config.get_head_size(),
        dtype=vllm_config.model_config.dtype,
        sliding_window=vllm_config.model_config.get_sliding_window(),
    )
|
||||
|
||||
|
||||
def create_vllm_config(
    model_name: str = "meta-llama/Meta-Llama-3-8B",
    tensor_parallel_size: int = 1,
    max_model_len: int = 1024,
    dtype: ModelDType | torch.dtype = "auto",
    num_gpu_blocks: int = 1000,
    block_size: int = 16,
    max_num_seqs: int = 256,
    max_num_batched_tokens: int = 8192,
    enable_chunked_prefill: bool = True,
    add_mock_model_methods: bool = True,
    hf_config_override: dict | None = None,
) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults.

    Args:
        model_name: HF model id used for both model and tokenizer.
        add_mock_model_methods: patch per-layer accessors onto the model
            config so backends can query them without a real model.
        hf_config_override: extra keys merged into the HF config.
    """

    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(
        block_size=block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    # Set cache blocks for testing
    # (these may be set during initialization normally)
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size,
    )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        max_model_len=model_config.max_model_len,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        # Add mock methods to satisfy backends that need them
        # This is a workaround because tests don't build full, real models,
        # but some backends expect to query the model for layer-specific
        # parameters
        import types

        model_config.get_num_layers = types.MethodType(lambda self: 1, model_config)
        model_config.get_sliding_window_for_layer = types.MethodType(
            lambda self, i: None, model_config
        )
        model_config.get_logits_soft_cap_for_layer = types.MethodType(
            lambda self, i: 0.0, model_config
        )
        model_config.get_sm_scale_for_layer = types.MethodType(
            lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config
        )

    if hf_config_override:
        model_config.hf_config.update(hf_config_override)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )
|
||||
|
||||
|
||||
def create_dummy_kv_cache(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int = 100,
) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing.

    Layout is (num_blocks, 2, block_size, num_kv_heads, head_size), where
    axis 1 holds the K and V planes.
    """
    cache_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
    return torch.randn(cache_shape, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
class BackendConfig:
    """Environment + compilation settings for one attention backend under test."""

    name: str  # human-readable backend label
    env_vars: dict  # environment variables to set before engine start
    comp_config: dict  # compilation config
    specific_gpu_arch: tuple | None = None  # required (major, minor) CC, if any
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested.
# Keyed by the same name stored in each BackendConfig; `env_vars` selects the
# backend via environment variables and `comp_config` sets the cudagraph mode.
# `specific_gpu_arch` gates a config to a (major, minor) compute capability:
# (9, 0) = Hopper, (10, 0) = Blackwell.
full_cg_backend_configs = {
    # FA3 on Hopper
    "FA3": BackendConfig(
        name="FA3",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
            "VLLM_FLASH_ATTN_VERSION": "3",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL",
        },
        specific_gpu_arch=(9, 0),
    ),
    # FlashMLA on Hopper
    "FlashMLA": BackendConfig(
        name="FlashMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASHMLA",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(9, 0),
    ),
    # Cutlass MLA on Blackwell
    "CutlassMLA": BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(10, 0),
    ),
    # FlashInfer MLA on Blackwell
    "FlashInferMLA": BackendConfig(
        name="FlashInferMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(10, 0),
    ),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA": BackendConfig(
        name="FlashAttentionMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
        specific_gpu_arch=(9, 0),
    ),
    # FA2 (no architecture gate — runs wherever FlashAttention v2 is supported)
    "FA2": BackendConfig(
        name="FA2",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
            "VLLM_FLASH_ATTN_VERSION": "2",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
    # Triton Attention
    "TritonAttn": BackendConfig(
        name="TritonAttn",
        env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
    # FlashInfer
    "FlashInfer": BackendConfig(
        name="FlashInfer",
        env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
    # NOTE(review): unlike the others, this sets no VLLM_ATTENTION_BACKEND —
    # it presumably relies on the platform-default ROCm backend and only
    # toggles the prefill/decode attention path; confirm against the runner.
    "RocmAttn": BackendConfig(
        name="RocmAttn",
        env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
        comp_config={
            "cudagraph_mode": "FULL",
        },
    ),
}
|
||||
Reference in New Issue
Block a user