Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,742 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
# HuggingFace model IDs: the Llama target model and its EAGLE / EAGLE-3
# draft checkpoints used throughout these tests.
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
    speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
    """Build an EagleProposer for the Llama target with an EAGLE draft head.

    The draft checkpoint is chosen from the module-level constants based on
    *method* ("eagle" selects the EAGLE-v1 head, anything else EAGLE-3).
    An optional token tree is serialized into the SpeculativeConfig.
    """
    target_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)

    # Pick the draft checkpoint that matches the requested method.
    draft_dir = eagle_dir if method == "eagle" else eagle3_dir

    tree_repr: str | None = None
    if speculative_token_tree is not None:
        # When a tree is given it must describe exactly one node per draft token.
        assert num_speculative_tokens == len(speculative_token_tree)
        tree_repr = str(speculative_token_tree)

    spec_config = SpeculativeConfig(
        target_model_config=target_config,
        target_parallel_config=ParallelConfig(),
        model=draft_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
        speculative_token_tree=tree_repr,
    )

    vllm_config = VllmConfig(
        model_config=target_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(
            max_model_len=target_config.max_model_len,
            is_encoder_decoder=target_config.is_encoder_decoder,
        ),
    )
    return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    Both paths must agree on the same expected result, including falling
    back to a request's backup token when sampling was skipped or the
    request was explicitly discarded.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    # explicitly discard the last request
    discarded_req_mask = torch.tensor(
        [False, False, False, True], dtype=torch.bool, device=device
    )

    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
    ]
    sampled_token_ids_tensor = torch.tensor(
        sampled_token_ids, dtype=torch.int32, device=device
    )
    # CPU-side variant of the same input: rejected (-1) entries removed and
    # discarded requests emptied out entirely.
    sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
    for i in range(len(sampled_token_ids_cpu)):
        if discarded_req_mask[i]:
            sampled_token_ids_cpu[i] = []

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(
        expected_next_token_ids_cpu, dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu,
        mock_requests,
        mock_input_batch,
        mock_num_scheduled_tokens,
    )
    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # The padded path also reports how many sampled tokens were valid per
    # request; skipped and discarded requests must report 0.
    expected_valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 0], dtype=torch.int32, device=device
    )

    next_token_ids_from_padded, valid_sampled_tokens_count = (
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata,
            sampled_token_ids_tensor,
            mock_requests,
            mock_input_batch,
            discarded_req_mask,
        )
    )
    assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
def test_prepare_inputs():
    """
    Verify prepare_inputs compacts the flattened token buffer after rejection.

    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device(current_platform.device_type)

    # q1 = 4, q2 = 7, q3 = 5
    # n1 = 1, n2 = 3, n3 = 2
    batch_spec = BatchSpec(
        seq_lens=[4, 7, 5],
        query_lens=[4, 7, 5],
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens

    # num rejected tokens is [1, 3, 2]
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    sampled_token_ids = [
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
        [
            ACCEPT_TOKEN,
            ACCEPT_TOKEN,
            ACCEPT_TOKEN,
            REJECT_TOKEN,
            REJECT_TOKEN,
            REJECT_TOKEN,
            BONUS_TOKEN,
        ],
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
    ]
    # prepare_inputs receives the CPU-side representation with the rejected
    # tokens already filtered out.
    sampled_token_ids = [
        [i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids
    ]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor(
        [0, 3, 7, 10], dtype=torch.int32, device=device
    )
    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3)
    # Second request: indices 4, 5, 6, 7 (keeping first 4 from positions 4-10)
    # Third request: indices 11, 12, 13 (keeping first 3 from positions 11-15)
    expected_token_indices = torch.tensor(
        [
            0,
            1,
            2,  # First request: 3 tokens (4-1)
            4,
            5,
            6,
            7,  # Second request: 4 tokens (7-3)
            11,
            12,
            13,  # Third request: 3 tokens (5-2)
        ],
        dtype=torch.int32,
        device=device,
    )

    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, sampled_token_ids, num_draft_tokens
    )

    assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)
def test_prepare_inputs_padded():
    """Exercise prepare_inputs_padded on a 3-request batch.

    Every request has query_len == 3 (two draft tokens plus one bonus
    token) and the per-request rejection counts are [1, 0, 2].  The padded
    path keeps the query layout intact, so query_start_loc must remain
    [0, 3, 6, 9] while the positions to sample from shift back by the
    number of rejected tokens, yielding [1, 5, 6].
    """
    device = torch.device(current_platform.device_type)
    num_speculative_tokens = 2

    spec = BatchSpec(seq_lens=[3, 3, 3], query_lens=[3, 3, 3])
    common_attn_metadata = create_common_attn_metadata(
        spec, block_size=16, device=device
    )

    # Dummy spec-decode metadata for two draft tokens per request, which makes
    # cu_num_draft_tokens [3, 6, 9] internally.
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    # with num_draft_tokens == [2, 2, 2] and num_rejected_tokens == [1, 0, 2].
    valid_counts = torch.tensor([2, 3, 1], dtype=torch.int32, device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)
    out_metadata, sample_indices = proposer.prepare_inputs_padded(
        common_attn_metadata, spec_decode_metadata, valid_counts
    )

    assert out_metadata.max_query_len == 3
    assert torch.equal(
        out_metadata.query_start_loc,
        torch.tensor([0, 3, 6, 9], dtype=torch.int32, device=device),
    )
    assert torch.equal(
        sample_indices, torch.tensor([1, 5, 6], dtype=torch.int32, device=device)
    )
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model(
    mock_get_model,
    mock_get_layers,
    mock_get_pp_group,
    method,
    attn_backend,
    pp_size,
    use_distinct_embed_tokens,
    use_distinct_lm_head,
    monkeypatch,
):
    """Verify EagleProposer.load_model shares or keeps separate the lm_head
    and embed_tokens with the target model, depending on whether the draft
    checkpoint ships its own copies and on the pipeline-parallel world size.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Setup draft model mock
    mock_model = mock.MagicMock()
    mock_model.model = mock.MagicMock()
    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
    if use_distinct_embed_tokens:
        mock_model.model.embed_tokens = mock.MagicMock()
    mock_model.has_own_lm_head = use_distinct_lm_head
    if use_distinct_lm_head:
        mock_model.lm_head = mock.MagicMock()

    mock_get_model.return_value = mock_model

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock(),
    }
    target_indx_layers: dict[str, mock.MagicMock] = {}
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()}
    all_indx_layers: dict[str, mock.MagicMock] = {}

    # Make mock_get_layers return different values for each call
    mock_get_layers.side_effect = [
        target_attn_layers,
        target_indx_layers,
        all_attn_layers,
        all_indx_layers,
    ]

    # Setup mock for pp group to return the appropriate value for world size
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = pp_size
    mock_get_pp_group.return_value = mock_pp_group

    # Set up the target model mock with a custom class so that
    # isinstance() checks match the expected type.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.lm_head = mock.MagicMock()
    target_model.model.embed_tokens = mock.MagicMock()

    from vllm.model_executor.models import SupportsMultiModal

    assert not isinstance(target_model, SupportsMultiModal)

    # Create proposer using the helper function
    proposer = _create_proposer(method, num_speculative_tokens=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify that the lm head is set correctly
    if use_distinct_lm_head:
        assert proposer.model.lm_head is not target_model.lm_head
    else:
        assert proposer.model.lm_head is target_model.lm_head

    # Verify that the embed tokens are set correctly
    # If pp_size is > 1, the embed tokens should be distinct
    if pp_size > 1 or use_distinct_embed_tokens:
        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
    else:
        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    """End-to-end check of EagleProposer.propose with a mocked draft model.

    The mocked model emits deterministic logits so each request's draft
    tokens must come out as consecutive integers starting from a known
    base token id (42... for request 1, 60... for request 2).
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "TREE_ATTN":
        pytest.skip(
            "TREE_ATTN is tested separately in test_propose_tree"
            "because it requires special input mocking."
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Use GPU device
    device = torch.device(current_platform.device_type)

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size.
    # Bug fix: use the parametrized `method` instead of hard-coding "eagle",
    # which silently left the "eagle3" parametrization untested.
    proposer = _create_proposer(method, num_speculative_tokens)
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    # We mock a model that returns deterministic logits
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock directly
    # Create the mock model with deterministic outputs
    model_mock = mock.MagicMock()

    # Setup for model forward calls
    forward_returns = []
    for i in range(num_speculative_tokens):
        if i == 0:
            # First call uses all tokens
            h_logits = torch.zeros(total_tokens, hidden_size, device=device)
            h_states = torch.zeros(total_tokens, hidden_size, device=device)
        else:
            # Subsequent calls use batch_size tokens
            h_logits = torch.zeros(batch_size, hidden_size, device=device)
            h_states = torch.zeros(batch_size, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))

    # For single token case, we only need the first item;
    # for multi-token, we need the sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Setup for compute_logits calls
    logits_returns = []
    for i in range(num_speculative_tokens):
        # For each call, increment the base token IDs
        current_tokens = [base_id + i for base_id in base_token_ids]
        logits_returns.append(create_deterministic_logits(current_tokens))

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Create input tensors
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    sampling_metadata = mock.MagicMock()

    if attn_backend == "FLASH_ATTN":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.FLASH_ATTN
        )
    elif attn_backend == "TRITON_ATTN":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.TRITON_ATTN
        )
    elif attn_backend == "TREE_ATTN":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.TREE_ATTN
        )
    elif attn_backend == "ROCM_AITER_FA":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.ROCM_AITER_FA
        )
    else:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    assert result.shape == (batch_size, num_speculative_tokens)

    # Create expected tokens based on our token pattern
    if num_speculative_tokens == 1:
        # Example for num_speculative_tokens=1:
        # [[42], [60]]
        expected_tokens = torch.tensor(
            [[base_token_ids[0]], [base_token_ids[1]]], device=device
        )
    else:
        # Example for num_speculative_tokens=3:
        # [[42, 43, 44], [60, 61, 62]]
        expected_tokens = torch.zeros(
            (batch_size, num_speculative_tokens), dtype=torch.int64, device=device
        )
        for i in range(batch_size):
            for j in range(num_speculative_tokens):
                expected_tokens[i, j] = base_token_ids[i] + j

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)
@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0,)],  # A single token
        [(0,), (0, 0), (0, 0, 0)],  # Chain
        [(0,), (1,), (2,)],  # Parallel
        [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],  # Tree
    ],
)
def test_propose_tree(spec_token_tree):
    """Exercise tree-based drafting (TREE_ATTN) with a mocked draft model.

    The mock produces deterministic, per-level decreasing logits so the
    proposed draft tokens are consecutive integers starting from each
    request's base token id, regardless of the tree topology.
    """
    # Get GPU device.
    device = torch.device(current_platform.device_type)

    # Setup test parameters.
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]
    num_speculative_tokens = len(spec_token_tree)

    # Create proposer first so we can use its actual hidden_size.
    proposer = _create_proposer(
        "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
    )
    # Get the hidden_size from the proposer to ensure consistency.
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids, k: int):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            # Assign decreasing values to the k, consecutive, tokens.
            for j in range(k):
                logits[i, token_id + j] = 100.0 - j
        return logits

    # Mock a model that returns deterministic logits.
    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)

    # Skip loading the model and replace it with a mock that returns
    # deterministic outputs.
    model_mock = mock.MagicMock()

    # Mock the model forward calls.
    forward_returns = [
        (
            torch.zeros(total_tokens, hidden_size, device=device),
            torch.zeros(total_tokens, hidden_size, device=device),
        )
    ]
    # One extra forward per tree level, sized by the cumulative draft count.
    for cu_num_drafts in proposer.cu_drafts_per_level:
        h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
        h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))
    model_mock.side_effect = forward_returns

    # Mock the compute_logits calls.
    cu_num_drafts_tensor = torch.tensor(
        [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device
    )
    logits_returns = []
    for level, num_children in enumerate(proposer.child_drafts_per_level):
        # Offset each level's base ids by the drafts already emitted so the
        # flattened result is one run of consecutive token ids per request.
        token_ids = base_token_ids + cu_num_drafts_tensor[level]
        level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level]
        level_logits = []
        for i in range(level_num_drafts // num_children):
            level_logits.append(
                create_deterministic_logits(token_ids + i * num_children, num_children)
            )
        logits_returns.append(torch.stack(level_logits, dim=1))
    model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Get the tree attention metadata builder.
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.TREE_ATTN
    )
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    # Setup inputs for the proposer.
    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )
    sampling_metadata = mock.MagicMock()

    # Propose draft tokens.
    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )
    assert result.shape == (batch_size, num_speculative_tokens)

    # The tokens are expected to be consecutive integers starting
    # from the base token IDs.
    expected_tokens = base_token_ids[:, None] + torch.arange(
        num_speculative_tokens, dtype=torch.int64, device=device
    )

    # Verify that the draft tokens match our expectations.
    assert torch.equal(result, expected_tokens)

View File

@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test whether spec decoding handles the max model length properly."""
import pytest
from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from vllm.sampling_params import StructuredOutputsParams
# Shared prompts for the max-model-len tests below.
_PROMPTS = [
    "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
    "Repeat the following sentence 10 times: Consistency is key to mastering any skill.",  # noqa: E501
    "Who won the Turing Award in 2018, and for what contribution? Describe in detail.",  # noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(num_speculative_tokens: int):
    """Generate with ngram speculation right up to max_model_len."""
    spec_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }
    engine = LLM(
        model="facebook/opt-125m",
        max_model_len=100,
        enforce_eager=True,  # For faster initialization.
        speculative_config=spec_config,
    )
    # max_tokens equals max_model_len, so decoding runs to the length cap.
    engine.generate(_PROMPTS, SamplingParams(max_tokens=100, ignore_eos=True))
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
):
    """Check that EAGLE speculation honors the draft model's smaller
    max_model_len (80) while the target model allows up to 200 tokens."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
            pytest.skip(
                "TRITON_ATTN does not support "
                "multi-token eagle spec decode on current platform"
            )

        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")

        llm = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config={
                "method": "eagle",
                "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
                "num_speculative_tokens": num_speculative_tokens,
                # Draft limit deliberately below the target's max_model_len.
                "max_model_len": 80,
            },
            max_model_len=200,
        )
        sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
        outputs = llm.generate(_PROMPTS, sampling_params)
        for o in outputs:
            assert o.outputs[0].finish_reason == "length", (
                "This test is only meaningful if the output "
                "is truncated due to max length"
            )

        # Second pass: structured decoding forces an exact output whose length
        # is known to carry the sequence past the draft's 80-token limit.
        sampling_params = SamplingParams(
            max_tokens=200,
            structured_outputs=StructuredOutputsParams(
                regex="^" + "a b c d e " * 15 + "$"
            ),
        )
        output = llm.generate(_PROMPTS, sampling_params)
        for o in output:
            assert o.prompt_token_ids is not None
            assert (
                len(o.prompt_token_ids)
                < 80
                < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
                <= 200
            ), (
                "This test is only meaningful if the output "
                "is longer than the eagle max length"
            )
            assert o.outputs[0].text == "a b c d e " * 15

View File

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
# Checkpoint used for the MTP (multi-token prediction) proposer tests.
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Create an MTP proposer with unified model configuration.

    The same MiMo checkpoint serves as both target and draft model
    (method "mtp"), so the proposer shares weights with the target.
    """
    target_config = ModelConfig(
        model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True
    )
    vllm_config = VllmConfig(
        model_config=target_config,
        cache_config=CacheConfig(),
        speculative_config=SpeculativeConfig(
            target_model_config=target_config,
            target_parallel_config=ParallelConfig(),
            model=mimo_7b_dir,
            method="mtp",
            num_speculative_tokens=num_speculative_tokens,
        ),
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(
            max_model_len=target_config.max_model_len,
            is_encoder_decoder=target_config.is_encoder_decoder,
        ),
    )
    return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group):
    """Test MTP-specific model loading with unified model approach.

    Since the draft reports has_own_embed_tokens/has_own_lm_head == False,
    load_model must share both modules with the target model.
    """
    # Setup mocks
    mock_model = mock.MagicMock()
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_get_model.return_value = mock_model
    # MTP does not have its own embed_tokens or lm_head
    # so it should share them with the target model
    mock_model.has_own_embed_tokens = False
    mock_model.has_own_lm_head = False

    target_attn_layers = {"target_attn_1": mock.MagicMock()}
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
    target_indexer_layers: dict = {}
    all_indexer_layers: dict = {}

    # One return per get_layers_from_vllm_config call:
    # target attn, target indexer, all attn, all indexer.
    mock_get_layers.side_effect = [
        target_attn_layers,
        target_indexer_layers,
        all_attn_layers,
        all_indexer_layers,
    ]

    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    mock_get_pp_group.return_value = mock_pp_group

    # Create target model with a custom class so isinstance() checks
    # against LlamaForCausalLM behave as for a real target model.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
    target_model.lm_head = mock.MagicMock()

    # Create MTP proposer
    proposer = _create_mtp_proposer(num_speculative_tokens=4)
    proposer.load_model(target_model)

    # Verify MTP-specific behavior:
    # Model is loaded
    mock_get_model.assert_called_once()
    # MTP shares lm_head with target model. Assert identity with `is`
    # (consistent with test_load_model) instead of mock `==`, which could
    # be satisfied by a configured __eq__ rather than actual sharing.
    assert proposer.model.lm_head is target_model.lm_head
    # MTP shares embed_tokens with target model
    assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Test that MTP's forward method returns hidden states directly"""
    device = torch.device(current_platform.device_type)
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Mock the MTP model to verify it returns hidden states directly
    model_mock = mock.MagicMock()

    # MTP returns hidden states directly
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device)
    else:
        # Multiple forward passes for multi-token speculation
        forward_returns = []
        for i in range(num_speculative_tokens):
            if i == 0:
                # First pass consumes every target token.
                h_states = torch.zeros(total_tokens, hidden_size, device=device)
            else:
                # Subsequent passes consume one token per request.
                h_states = torch.zeros(batch_size, hidden_size, device=device)
            forward_returns.append(h_states)
        model_mock.side_effect = forward_returns

    # Mock compute_logits
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
        # Every request's winning token is `token_offset`.
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        logits[:, token_offset] = 100.0
        return logits

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = create_deterministic_logits(
            batch_size, vocab_size, 42
        )
    else:
        logits_returns = [
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
            for i in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Prepare inputs
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
    common_attn_metadata = create_common_attn_metadata(
        batch_spec, block_size=16, device=device
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [
            torch.arange(seq_lens[0], device=device),
            torch.arange(seq_lens[1], device=device),
        ]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    sampling_metadata = mock.MagicMock()

    # Setup attention metadata
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.FLASH_ATTN
    )

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = attn_metadata_builder

    # Run propose
    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    # Verify the model was called correctly
    assert model_mock.called
    # Verify output shape
    assert result.shape == (batch_size, num_speculative_tokens)

View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.config import (
ModelConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.v1.spec_decode.ngram_proposer import (
NgramProposer,
_find_longest_matched_ngram_and_propose_tokens,
)
def test_find_longest_matched_ngram_and_propose_tokens():
    """Unit-test the ngram matcher: longest suffix match wins, proposals are
    capped at k tokens, and the earliest occurrence breaks ties."""

    def propose(tokens, min_ngram, max_ngram, k):
        # Thin wrapper so each scenario below reads as a one-liner.
        return _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=np.array(tokens),
            min_ngram=min_ngram,
            max_ngram=max_ngram,
            max_model_len=1024,
            k=k,
        )

    # The trailing 2-gram (5, 6) never appeared earlier -> nothing proposed.
    assert propose([1, 2, 3, 4, 1, 2, 3, 5, 6], 2, 2, 2).size == 0

    # Suffix (1, 2) [or just (3,) for 1-grams] matches at the start; the
    # tokens following the match are proposed, truncated to k.
    repeated = [1, 2, 3, 4, 1, 2, 3]
    for n in (1, 2):
        np.testing.assert_array_equal(propose(repeated, n, n, 3), np.array([4, 1, 2]))
        np.testing.assert_array_equal(propose(repeated, n, n, 2), np.array([4, 1]))

    mixed = [1, 3, 6, 2, 3, 4, 1, 2, 3]
    np.testing.assert_array_equal(propose(mixed, 2, 2, 3), np.array([4, 1, 2]))
    # Return on the first match: the earliest occurrence of the final
    # token (3) is at index 1, so (6, 2) is proposed.
    np.testing.assert_array_equal(propose(mixed, 1, 1, 2), np.array([6, 2]))
def test_ngram_proposer():
    """End-to-end checks of ``NgramProposer.propose``.

    Covers: no-match cases, longest-ngram preference, first-match
    tie-breaking, empty inputs, padded multi-batch inputs, non-contiguous
    decode indices (prefill requests interleaved), and the fallback path
    when no numba threads are available.
    """

    def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m")
        return NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=min_n,
                    prompt_lookup_max=max_n,
                    num_speculative_tokens=k,
                    method="ngram",
                ),
            )
        )

    # No match.
    token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0
    # No match for 4-gram.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0
    # No match for 4-gram but match for 3-gram.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[4, 1]]))
    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 1]]
    # Match for 2-gram and 3-gram, but not 4-gram.
    token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 2]]
    # Multiple 3-gram matched, but always pick the first one.
    token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[100, 1]]))
    # check empty input
    token_ids_cpu = np.array([[]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0
    # check multibatch input
    # first request has 5 tokens and a match
    # second request has 3 tokens and no match. Padded with -1 for max len 5
    token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([5, 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([3, 1]))
    assert np.array_equal(result[1], np.array([]))
    # Test non-contiguous indices: requests 0 and 2 need proposals,
    # request 1 is in prefill
    proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    max_model_len = 20
    token_ids_cpu = np.zeros((3, max_model_len), dtype=np.int32)
    token_ids_cpu[0, :5] = [1, 2, 3, 1, 2]
    token_ids_cpu[1, :3] = [4, 5, 6]
    token_ids_cpu[2, :5] = [7, 8, 9, 7, 8]
    num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32)
    sampled_token_ids = [[2], [], [8]]  # Empty list for request 1 simulates prefill
    result = proposer.propose(
        sampled_token_ids=sampled_token_ids,
        req_ids=["0", "1", "2"],
        num_tokens_no_spec=num_tokens_no_spec,
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result) == 3
    assert np.array_equal(result[0], [3, 1])
    assert len(result[1]) == 0
    assert np.array_equal(result[2], [9, 7])
    # Verify internal arrays written to correct indices
    assert proposer.valid_ngram_num_drafts[0] == 2
    assert proposer.valid_ngram_num_drafts[1] == 0
    assert proposer.valid_ngram_num_drafts[2] == 2
    assert np.array_equal(proposer.valid_ngram_draft[0, :2], [3, 1])
    assert np.array_equal(proposer.valid_ngram_draft[2, :2], [9, 7])
    # test if 0 threads available: can happen if TP size > CPU count
    ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    ngram_proposer.num_numba_thread_available = 0
    # set max_model_len to 2 * threshold to ensure multithread is used
    num_tokens_threshold = ngram_proposer.num_tokens_threshold
    ngram_proposer.max_model_len = 2 * num_tokens_threshold
    # using multibatch test
    middle_integer = num_tokens_threshold // 2
    # list(range(...)) instead of a trivial comprehension (ruff C416).
    input_1 = list(range(num_tokens_threshold))
    input_1 += [middle_integer, middle_integer + 1]
    input_2 = [-1] * len(input_1)
    input_2[:3] = [4, 5, 6]
    token_ids_cpu = np.array([input_1, input_2])
    result = ngram_proposer.propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([len(input_1), 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
    assert np.array_equal(result[1], np.array([]))

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.config import SpeculativeConfig
from vllm.model_executor.models.interfaces import supports_eagle3
from vllm.platforms import current_platform
@pytest.mark.parametrize(
    "model_path",
    [
        pytest.param(
            "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
            id="llama3-eagle3-speculator",
        ),
        pytest.param(
            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
            id="qwen3-eagle3-speculator",
        ),
        pytest.param(
            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
            id="qwen3-eagle3-speculator-w4a16-verifier",
            marks=pytest.mark.skipif(
                current_platform.is_rocm(),
                reason="The tests are skipped on rocm platform.",
            ),
        ),
    ],
)
def test_eagle3_speculators_model(
    vllm_runner, example_prompts, model_path, monkeypatch
):
    """
    Test Eagle3 speculators models properly initialize speculative decoding.
    This test verifies:
    1. Eagle3 support is detected for the model
    2. Speculative config is automatically initialized from embedded config
    3. The draft model path is correctly set to the speculators model
    4. Speculative tokens count is valid
    5. Text generation works with speculative decoding enabled
    """
    # Set environment variable for V1 engine serialization
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        # Verify Eagle3 support is detected
        eagle3_supported = vllm_model.apply_model(supports_eagle3)
        assert eagle3_supported, f"Eagle3 should be supported for {model_path}"
        # No explicit speculative_config was passed to vllm_runner, so any
        # config present must have been built from the checkpoint itself.
        vllm_config = vllm_model.llm.llm_engine.vllm_config
        assert isinstance(vllm_config.speculative_config, SpeculativeConfig), (
            "Speculative config should be initialized for speculators model"
        )
        spec_config = vllm_config.speculative_config
        assert spec_config.num_speculative_tokens > 0, (
            f"Expected positive speculative tokens, "
            f"got {spec_config.num_speculative_tokens}"
        )
        # The draft model should resolve to the speculators checkpoint itself.
        assert spec_config.model == model_path, (
            f"Draft model should be {model_path}, got {spec_config.model}"
        )
        # Smoke-test greedy generation with speculative decoding enabled.
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20)
        assert vllm_outputs, f"No outputs generated for speculators model {model_path}"

View File

@@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
# These tests exercise flash-attn's varlen kernels; skip the whole module
# when the installed build does not provide flash_attn_varlen_func.
if not is_flash_attn_varlen_func_available():
    pytest.skip(
        "This test requires flash_attn_varlen_func, but it's not available.",
        allow_module_level=True,
    )
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer.

    It only carries the q/k/v scale tensors that backend implementations
    read during ``forward``; the module itself does no computation.
    """

    # Unit scales -> no extra quantization scaling applied by the backend.
    # NOTE: class attributes, so they are created on CUDA at import time.
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity; the real attention math happens in the backend impl.
        return x
def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run one attention forward pass through the given backend.

    ``q``/``k``/``v`` are shaped (batch, q_len, heads, head_dim); every
    request in the batch shares the same query length and the same total KV
    length ``seqlen_k``. ``spec_token_tree``/``num_spec_tokens`` configure a
    speculative config (used by the tree-attention backend). Returns the
    backend output for the flattened (batch * q_len) tokens.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]
    # Initialize the query and KV sequence lengths.
    # query_start_loc = [0, q_len, 2*q_len, ...]: uniform queries per request.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    # Tokens already in the cache before this step's queries.
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    num_actual_tokens = query_start_loc[-1]
    softmax_scale = q.shape[-1] ** (-0.5)
    layer = MockAttentionLayer()
    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens.cpu(),
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )
    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )
    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )
    # Run forward pass and return output.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    # kv_cache is cloned so the caller's cache is not mutated by this pass.
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )
def test_tree_attn_correctness() -> None:
    """Cross-check TREE_ATTN against FLASH_ATTN.

    For each causal mask describing a speculative token tree, every root-to-
    node branch of the tree is also run as a plain causal chain through
    flash attention; the per-branch outputs must match the tree-attention
    outputs at the corresponding query positions.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    device = "cuda"
    # Keys are the stringified spec_token_tree passed to SpeculativeConfig;
    # values are the corresponding (tree_size x tree_size) causal masks.
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }
    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True
    # Sweep batch size, head count, and context length for each tree shape.
    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Assert that the number of heads is divisible
                    # by the number of KV heads.
                    assert num_heads % num_kv_heads == 0
                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    # Total KV length = existing context + the tree tokens.
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    # Set up the block table and KV cache for paged KV.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    # Random cache contents stand in for the (unchecked)
                    # pre-existing context KVs; both backends see the same.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )
                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )
                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)
                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        # Row q_index of the mask selects exactly the
                        # ancestors of node q_index (plus itself).
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]
                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )
                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=AttentionBackendEnum.FLASH_ATTN,
                        ).view(batch_size, -1, num_heads, dim_per_head)
                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )
def _gen_slot_mapping(
positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
block_indices = positions // block_size
blocks = block_table.gather(dim=1, index=block_indices)
return (blocks * block_size + positions % block_size).view(-1)