Sync from v0.13
This commit is contained in:
742
tests/v1/spec_decode/test_eagle.py
Normal file
742
tests/v1/spec_decode/test_eagle.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
# Target model plus the matching EAGLE (v1) and EAGLE-3 draft checkpoints
# used by _create_proposer below.
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
    speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
    """Build an ``EagleProposer`` backed by a minimal ``VllmConfig``.

    The draft checkpoint is selected from the module-level ``eagle_dir`` /
    ``eagle3_dir`` constants according to ``method``.  When a token tree is
    given, its length must equal ``num_speculative_tokens``.
    """
    target_cfg = ModelConfig(model=model_dir, runner="generate", max_model_len=100)

    # "eagle" maps to the v1 draft checkpoint; everything else uses eagle3.
    draft_dir = eagle3_dir if method != "eagle" else eagle_dir

    if speculative_token_tree is None:
        tree_repr = None
    else:
        # SpeculativeConfig expects the tree serialized as a string.
        assert len(speculative_token_tree) == num_speculative_tokens
        tree_repr = str(speculative_token_tree)

    spec_cfg = SpeculativeConfig(
        target_model_config=target_cfg,
        target_parallel_config=ParallelConfig(),
        model=draft_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
        speculative_token_tree=tree_repr,
    )

    cfg = VllmConfig(
        model_config=target_cfg,
        cache_config=CacheConfig(),
        speculative_config=spec_cfg,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(
            max_model_len=target_cfg.max_model_len,
            is_encoder_decoder=target_cfg.is_encoder_decoder,
        ),
    )

    return EagleProposer(vllm_config=cfg, device=current_platform.device_type)
|
||||
|
||||
|
||||
def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    """
    device = torch.device(current_platform.device_type)

    # 4 requests, each with 4 draft tokens + 1 bonus token per step.
    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    # Minimal InputBatch stand-in: only req_ids / num_reqs / vocab_size are
    # read by the methods under test (confirmed by the calls below).
    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    # explicitly discard the last request
    discarded_req_mask = torch.tensor(
        [False, False, False, True], dtype=torch.bool, device=device
    )
    # -1 marks a rejected draft position in the padded GPU representation.
    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
    ]
    sampled_token_ids_tensor = torch.tensor(
        sampled_token_ids, dtype=torch.int32, device=device
    )
    # CPU variant drops rejected tokens entirely instead of padding with -1.
    sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
    for i in range(len(sampled_token_ids_cpu)):
        if discarded_req_mask[i]:
            sampled_token_ids_cpu[i] = []

    # Last sampled token per request, falling back to the mocked backup
    # token (req index * 10) when nothing was sampled.
    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(
        expected_next_token_ids_cpu, dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu,
        mock_requests,
        mock_input_batch,
        mock_num_scheduled_tokens,
    )

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Valid counts: 2 and 5 accepted for the first two requests; 0 for the
    # skipped request and 0 for the explicitly discarded one.
    expected_valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 0], dtype=torch.int32, device=device
    )

    next_token_ids_from_padded, valid_sampled_tokens_count = (
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata,
            sampled_token_ids_tensor,
            mock_requests,
            mock_input_batch,
            discarded_req_mask,
        )
    )

    # Both code paths must agree on the resulting next-token tensor.
    assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
|
||||
|
||||
|
||||
def test_prepare_inputs():
    """
    Verify EagleProposer.prepare_inputs trims rejected draft tokens.

    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device(current_platform.device_type)

    # q1 = 4, q2 = 7, q3 = 5
    # n1 = 1, n2 = 3, n3 = 2

    batch_spec = BatchSpec(
        seq_lens=[4, 7, 5],
        query_lens=[4, 7, 5],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    # num rejected tokens is [1, 3, 2]
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    sampled_token_ids = [
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
        [
            ACCEPT_TOKEN,
            ACCEPT_TOKEN,
            ACCEPT_TOKEN,
            REJECT_TOKEN,
            REJECT_TOKEN,
            REJECT_TOKEN,
            BONUS_TOKEN,
        ],
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
    ]
    # prepare_inputs takes the CPU form with rejected tokens removed.
    sampled_token_ids = [
        [i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids
    ]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor(
        [0, 3, 7, 10], dtype=torch.int32, device=device
    )

    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3)
    # Second request: indices 4, 5, 6, 7 (keeping first 4 from positions 4-10)
    # Third request: indices 11, 12, 13 (keeping first 3 from positions 11-15)
    expected_token_indices = torch.tensor(
        [
            0,
            1,
            2,  # First request: 3 tokens (4-1)
            4,
            5,
            6,
            7,  # Second request: 4 tokens (7-3)
            11,
            12,
            13,  # Third request: 3 tokens (5-2)
        ],
        dtype=torch.int32,
        device=device,
    )
    # Only one speculative token is needed; prepare_inputs does not depend
    # on the proposer's speculative depth in this test.
    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, sampled_token_ids, num_draft_tokens
    )

    assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)
|
||||
|
||||
|
||||
def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
    from the original indices to sample from.
    """

    device = torch.device(current_platform.device_type)
    n_spec = 2

    # Three uniform requests, each contributing 3 query tokens.
    metadata_in = create_common_attn_metadata(
        BatchSpec(
            seq_lens=[3, 3, 3],
            query_lens=[3, 3, 3],
        ),
        block_size=16,
        device=device,
    )

    # Dummy draft metadata; only its cu_num_draft_tokens layout matters,
    # which is expected to be [3, 6, 9].
    dummy_spec_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * n_spec] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_counts = torch.tensor([2, 3, 1], dtype=torch.int32, device=device)

    metadata_out, sample_indices = _create_proposer(
        "eagle", n_spec
    ).prepare_inputs_padded(metadata_in, dummy_spec_metadata, valid_counts)

    # Padded preparation keeps full query spans, so query_start_loc is
    # unchanged and max_query_len stays 3.
    assert metadata_out.max_query_len == 3
    assert torch.equal(
        metadata_out.query_start_loc,
        torch.tensor([0, 3, 6, 9], dtype=torch.int32, device=device),
    )
    assert torch.equal(
        sample_indices,
        torch.tensor([1, 5, 6], dtype=torch.int32, device=device),
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model(
    mock_get_model,
    mock_get_layers,
    mock_get_pp_group,
    method,
    attn_backend,
    pp_size,
    use_distinct_embed_tokens,
    use_distinct_lm_head,
    monkeypatch,
):
    """Verify EagleProposer.load_model shares or keeps separate the lm_head
    and embed_tokens between draft and target models.

    Sharing is expected only when the draft has no weights of its own and
    (for embed_tokens) pipeline parallelism is disabled.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Setup draft model mock
    mock_model = mock.MagicMock()
    mock_model.model = mock.MagicMock()
    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
    if use_distinct_embed_tokens:
        mock_model.model.embed_tokens = mock.MagicMock()
    mock_model.has_own_lm_head = use_distinct_lm_head
    if use_distinct_lm_head:
        mock_model.lm_head = mock.MagicMock()

    mock_get_model.return_value = mock_model

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock(),
    }
    target_indx_layers: dict[str, mock.MagicMock] = {}
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()}

    all_indx_layers: dict[str, mock.MagicMock] = {}

    # Make mock_get_layers return different values for each call
    # NOTE(review): the order (target attn, target indexer, all attn, all
    # indexer) must match load_model's internal call sequence — confirm
    # against vllm.v1.spec_decode.eagle if load_model changes.
    mock_get_layers.side_effect = [
        target_attn_layers,
        target_indx_layers,
        all_attn_layers,
        all_indx_layers,
    ]

    # Setup mock for pp group to return the appropriate value for world size
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = pp_size
    mock_get_pp_group.return_value = mock_pp_group

    # Set up the target model mock with a custom class so that
    # isinstance() checks match the expected type.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.lm_head = mock.MagicMock()
    target_model.model.embed_tokens = mock.MagicMock()

    from vllm.model_executor.models import SupportsMultiModal

    # Sanity check: the stub must not be treated as a multimodal model,
    # otherwise load_model would take a different code path.
    assert not isinstance(target_model, SupportsMultiModal)

    # Create proposer using the helper function
    proposer = _create_proposer(method, num_speculative_tokens=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify that the lm head is set correctly: shared with the target only
    # when the draft has no lm_head of its own.
    if use_distinct_lm_head:
        assert proposer.model.lm_head is not target_model.lm_head
    else:
        assert proposer.model.lm_head is target_model.lm_head

    # Verify that the embed tokens are set correctly
    # If pp_size is > 1, the embed tokens should be distinct
    if pp_size > 1 or use_distinct_embed_tokens:
        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
    else:
        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    """End-to-end check of EagleProposer.propose with a fully mocked draft
    model that emits deterministic logits, so the proposed token ids can be
    predicted exactly.

    NOTE(review): the proposer below is always built with method="eagle"
    even though `method` is parametrized — confirm whether "eagle3" was
    meant to be exercised here.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "TREE_ATTN":
        pytest.skip(
            "TREE_ATTN is tested separately in test_propose_tree"
            "because it requires special input mocking."
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Use GPU device
    device = torch.device(current_platform.device_type)

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size
    proposer = _create_proposer("eagle", num_speculative_tokens)
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids):
        # One row per request; the desired token gets a large logit so
        # argmax sampling deterministically picks it.
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    # We mock a model that returns deterministic logits
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock directly
    # Create the mock model with deterministic outputs
    model_mock = mock.MagicMock()

    # Setup for model forward calls
    forward_returns = []
    for i in range(num_speculative_tokens):
        if i == 0:
            # First call uses all tokens
            h_logits = torch.zeros(total_tokens, hidden_size, device=device)
            h_states = torch.zeros(total_tokens, hidden_size, device=device)
        else:
            # Subsequent calls use batch_size tokens
            h_logits = torch.zeros(batch_size, hidden_size, device=device)
            h_states = torch.zeros(batch_size, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))

    # For single token case, we only need the first item;
    # for multi-token, we need the sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Setup for compute_logits calls
    logits_returns = []
    for i in range(num_speculative_tokens):
        # For each call, increment the base token IDs
        current_tokens = [base_id + i for base_id in base_token_ids]
        logits_returns.append(create_deterministic_logits(current_tokens))

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Create input tensors
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Random inputs are fine here: the mocked model ignores them, only the
    # shapes must be consistent.
    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    sampling_metadata = mock.MagicMock()

    # Resolve the metadata-builder class matching the parametrized backend.
    if attn_backend == "FLASH_ATTN":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.FLASH_ATTN
        )
    elif attn_backend == "TRITON_ATTN":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.TRITON_ATTN
        )
    elif attn_backend == "TREE_ATTN":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.TREE_ATTN
        )
    elif attn_backend == "ROCM_AITER_FA":
        attn_metadata_builder_cls, _ = try_get_attention_backend(
            AttentionBackendEnum.ROCM_AITER_FA
        )
    else:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    assert result.shape == (batch_size, num_speculative_tokens)

    # Create expected tokens based on our token pattern
    if num_speculative_tokens == 1:
        # Example for num_speculative_tokens=1:
        # [[42], [60]]
        expected_tokens = torch.tensor(
            [[base_token_ids[0]], [base_token_ids[1]]], device=device
        )
    else:
        # Example for num_speculative_tokens=3:
        # [[42, 43, 44], [60, 61, 62]]
        expected_tokens = torch.zeros(
            (batch_size, num_speculative_tokens), dtype=torch.int64, device=device
        )
        for i in range(batch_size):
            for j in range(num_speculative_tokens):
                expected_tokens[i, j] = base_token_ids[i] + j

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0,)],  # A single token
        [(0,), (0, 0), (0, 0, 0)],  # Chain
        [(0,), (1,), (2,)],  # Parallel
        [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],  # Tree
    ],
)
def test_propose_tree(spec_token_tree):
    """Exercise EagleProposer.propose with TREE_ATTN over several token-tree
    shapes, using a mocked draft model whose logits make the drafted tokens
    consecutive integers starting at per-request base ids.
    """
    # Get GPU device.
    device = torch.device(current_platform.device_type)

    # Setup test parameters.
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]
    num_speculative_tokens = len(spec_token_tree)

    # Create proposer first so we can use its actual hidden_size.
    proposer = _create_proposer(
        "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
    )
    # Get the hidden_size from the proposer to ensure consistency.
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids, k: int):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            # Assign decreasing values to the k, consecutive, tokens.
            for j in range(k):
                logits[i, token_id + j] = 100.0 - j
        return logits

    # Mock a model that returns deterministic logits.
    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)

    # Skip loading the model and replace it with a mock that returns
    # deterministic outputs.
    model_mock = mock.MagicMock()

    # Mock the model forward calls.
    # First forward consumes the full flattened batch; one further forward
    # per tree level consumes batch_size * cumulative-drafts tokens.
    forward_returns = [
        (
            torch.zeros(total_tokens, hidden_size, device=device),
            torch.zeros(total_tokens, hidden_size, device=device),
        )
    ]
    for cu_num_drafts in proposer.cu_drafts_per_level:
        h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
        h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))
    model_mock.side_effect = forward_returns

    # Mock the compute_logits calls.
    cu_num_drafts_tensor = torch.tensor(
        [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device
    )
    logits_returns = []
    for level, num_children in enumerate(proposer.child_drafts_per_level):
        # Offset the base ids so each level yields fresh consecutive tokens.
        token_ids = base_token_ids + cu_num_drafts_tensor[level]
        level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level]
        level_logits = []
        for i in range(level_num_drafts // num_children):
            level_logits.append(
                create_deterministic_logits(token_ids + i * num_children, num_children)
            )
        logits_returns.append(torch.stack(level_logits, dim=1))
    model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Get the tree attention metadata builder.
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.TREE_ATTN
    )
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    # Setup inputs for the proposer.
    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )
    sampling_metadata = mock.MagicMock()

    # Propose draft tokens.
    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )
    assert result.shape == (batch_size, num_speculative_tokens)

    # The tokens are expected to be consecutive integers starting
    # from the base token IDs.
    expected_tokens = base_token_ids[:, None] + torch.arange(
        num_speculative_tokens, dtype=torch.int64, device=device
    )

    # Verify that the draft tokens match our expectations.
    assert torch.equal(result, expected_tokens)
|
||||
90
tests/v1/spec_decode/test_max_len.py
Normal file
90
tests/v1/spec_decode/test_max_len.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test whether spec decoding handles the max model length properly."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
|
||||
# Prompts chosen to elicit long, repetitive completions so that generation
# reliably reaches the max-length cutoffs exercised below.
_PROMPTS = [
    "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
    "Repeat the following sentence 10 times: Consistency is key to mastering any skill.",  # noqa: E501
    "Who won the Turing Award in 2018, and for what contribution? Describe in detail.",  # noqa: E501
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(num_speculative_tokens: int):
    """Smoke test: ngram spec decode must honor max_model_len=100 without
    crashing, for several speculative depths."""
    spec_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }
    engine = LLM(
        model="facebook/opt-125m",
        max_model_len=100,
        enforce_eager=True,  # For faster initialization.
        speculative_config=spec_config,
    )
    # ignore_eos forces generation to run all the way to the length limit.
    engine.generate(_PROMPTS, SamplingParams(max_tokens=100, ignore_eos=True))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
):
    """Verify eagle spec decode with a draft max_model_len (80) shorter than
    the target's (200): generation must run past 80 tokens and still respect
    the target limit, both with and without structured outputs.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
            pytest.skip(
                "TRITON_ATTN does not support "
                "multi-token eagle spec decode on current platform"
            )

        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")

        llm = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config={
                "method": "eagle",
                "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
                "num_speculative_tokens": num_speculative_tokens,
                # Draft limit deliberately below the target's max_model_len.
                "max_model_len": 80,
            },
            max_model_len=200,
        )
        sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
        outputs = llm.generate(_PROMPTS, sampling_params)
        for o in outputs:
            assert o.outputs[0].finish_reason == "length", (
                "This test is only meaningful if the output "
                "is truncated due to max length"
            )

        # Second pass: a regex-constrained output long enough to cross the
        # draft's 80-token limit while staying within the target's 200.
        sampling_params = SamplingParams(
            max_tokens=200,
            structured_outputs=StructuredOutputsParams(
                regex="^" + "a b c d e " * 15 + "$"
            ),
        )
        output = llm.generate(_PROMPTS, sampling_params)
        for o in output:
            assert o.prompt_token_ids is not None
            assert (
                len(o.prompt_token_ids)
                < 80
                < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
                <= 200
            ), (
                "This test is only meaningful if the output "
                "is longer than the eagle max length"
            )
            # The constrained decode must reproduce the regex exactly.
            assert o.outputs[0].text == "a b c d e " * 15
|
||||
215
tests/v1/spec_decode/test_mtp.py
Normal file
215
tests/v1/spec_decode/test_mtp.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
|
||||
# HF repo id of the MTP-capable target model these tests build proposers for.
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
|
||||
|
||||
|
||||
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Build an EagleProposer configured for MTP speculative decoding.

    Both the target and the draft point at the same MiMo-7B checkpoint
    (the "unified model" setup); only ``num_speculative_tokens`` varies
    between tests.
    """
    device = current_platform.device_type

    target_config = ModelConfig(
        model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True
    )

    spec_config = SpeculativeConfig(
        target_model_config=target_config,
        target_parallel_config=ParallelConfig(),
        model=mimo_7b_dir,
        method="mtp",
        num_speculative_tokens=num_speculative_tokens,
    )

    # Scheduler limits mirror the (tiny) target model config.
    sched_config = SchedulerConfig(
        max_model_len=target_config.max_model_len,
        is_encoder_decoder=target_config.is_encoder_decoder,
    )

    config = VllmConfig(
        model_config=target_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device=device),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=sched_config,
    )

    return EagleProposer(vllm_config=config, device=device)
|
||||
|
||||
|
||||
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group):
    """MTP model loading must share embed_tokens/lm_head with the target."""
    # --- Draft-model mock returned by get_model(). ---
    draft = mock.MagicMock()
    draft.model.embed_tokens.weight.shape = (131072, 4096)
    # MTP has neither its own embed_tokens nor its own lm_head, so both
    # must be taken from the target model.
    draft.has_own_embed_tokens = False
    draft.has_own_lm_head = False
    mock_get_model.return_value = draft

    # --- Layer discovery mock: get_layers_from_vllm_config is consumed in
    # the order target-attn, target-indexer, all-attn, all-indexer. ---
    target_attn = {"target_attn_1": mock.MagicMock()}
    mock_get_layers.side_effect = [
        target_attn,
        {},  # target indexer layers
        {**target_attn, "draft_attn_1": mock.MagicMock()},
        {},  # all indexer layers
    ]

    # --- Single-rank pipeline-parallel group. ---
    pp_group = mock.MagicMock()
    pp_group.world_size = 1
    mock_get_pp_group.return_value = pp_group

    # --- Target model: autospec'd against a Llama subclass so isinstance
    # checks in load_model() behave like the real thing. ---
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target = mock.create_autospec(_TargetModelStub, instance=True)
    target.model = mock.MagicMock()
    target.model.embed_tokens.weight.shape = (131072, 4096)
    target.lm_head = mock.MagicMock()

    proposer = _create_mtp_proposer(num_speculative_tokens=4)
    proposer.load_model(target)

    # The draft model was loaded exactly once...
    mock_get_model.assert_called_once()
    # ...and shares lm_head and embed_tokens with the target model.
    assert proposer.model.lm_head == target.lm_head
    assert proposer.model.model.embed_tokens == target.model.embed_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Test that MTP's forward method returns hidden states directly.

    Unlike Eagle, the mocked MTP model's __call__ yields hidden states
    (not (hidden, residual) tuples); this test checks that propose()
    consumes them and emits a (batch, num_speculative_tokens) draft.
    """

    device = torch.device(current_platform.device_type)
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Mock the MTP model to verify it returns hidden states directly
    model_mock = mock.MagicMock()

    # MTP returns hidden states directly
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device)
    else:
        # Multiple forward passes for multi-token speculation: the first
        # pass sees all prompt tokens, subsequent passes one token per request.
        forward_returns = []
        for i in range(num_speculative_tokens):
            if i == 0:
                h_states = torch.zeros(total_tokens, hidden_size, device=device)
            else:
                h_states = torch.zeros(batch_size, hidden_size, device=device)
            forward_returns.append(h_states)
        model_mock.side_effect = forward_returns

    # Mock compute_logits: a single spike at token_offset makes greedy
    # sampling deterministic.
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        logits[:, token_offset] = 100.0
        return logits

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = create_deterministic_logits(
            batch_size, vocab_size, 42
        )
    else:
        # One logits tensor per speculation step, each spiking a new token id.
        logits_returns = [
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
            for i in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Prepare inputs
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
    common_attn_metadata = create_common_attn_metadata(
        batch_spec, block_size=16, device=device
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    # Positions restart at 0 for each request (two flattened sequences).
    target_positions = torch.cat(
        [
            torch.arange(seq_lens[0], device=device),
            torch.arange(seq_lens[1], device=device),
        ]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    sampling_metadata = mock.MagicMock()

    # Setup attention metadata
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.FLASH_ATTN
    )

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = attn_metadata_builder

    # Run propose
    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    # Verify the model was called correctly
    assert model_mock.called
    # Verify output shape: one draft token row per request.
    assert result.shape == (batch_size, num_speculative_tokens)
|
||||
224
tests/v1/spec_decode/test_ngram.py
Normal file
224
tests/v1/spec_decode/test_ngram.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
|
||||
from vllm.config import (
|
||||
ModelConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.v1.spec_decode.ngram_proposer import (
|
||||
NgramProposer,
|
||||
_find_longest_matched_ngram_and_propose_tokens,
|
||||
)
|
||||
|
||||
|
||||
def test_find_longest_matched_ngram_and_propose_tokens():
    """Exercise the ngram-match kernel on hand-crafted token sequences."""
    # Tail [5, 6] matches no earlier 2-gram -> empty proposal.
    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    no_match = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
    )
    assert len(no_match) == 0

    # (tokens, min_ngram, max_ngram, k, expected proposal)
    cases = [
        (np.array([1, 2, 3, 4, 1, 2, 3]), 2, 2, 3, [4, 1, 2]),
        (np.array([1, 2, 3, 4, 1, 2, 3]), 2, 2, 2, [4, 1]),
        (np.array([1, 2, 3, 4, 1, 2, 3]), 1, 1, 3, [4, 1, 2]),
        (np.array([1, 2, 3, 4, 1, 2, 3]), 1, 1, 2, [4, 1]),
        (np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]), 2, 2, 3, [4, 1, 2]),
        # Returns on the first (earliest) match of the trailing 1-gram [3].
        (np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]), 1, 1, 2, [6, 2]),
    ]
    for tokens, min_n, max_n, k, expected in cases:
        np.testing.assert_array_equal(
            _find_longest_matched_ngram_and_propose_tokens(
                origin_tokens=tokens,
                min_ngram=min_n,
                max_ngram=max_n,
                max_model_len=1024,
                k=k,
            ),
            np.array(expected),
        )
|
||||
|
||||
|
||||
def test_ngram_proposer():
    """End-to-end NgramProposer.propose() over single/multi-batch inputs,
    covering no-match, longest-ngram preference, first-match preference,
    empty input, padded batches, non-contiguous request indices, and the
    zero-numba-threads fallback path.
    """

    def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m")
        return NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=min_n,
                    prompt_lookup_max=max_n,
                    num_speculative_tokens=k,
                    method="ngram",
                ),
            )
        )

    # No match.
    token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # No match for 4-gram.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # No match for 4-gram but match for 3-gram.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[4, 1]]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 1]]

    # Match for 2-gram and 3-gram, but not 4-gram.
    token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 2]]

    # Multiple 3-gram matched, but always pick the first one.
    token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[100, 1]]))

    # check empty input
    token_ids_cpu = np.array([[]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # check multibatch input
    # first request has 5 tokens and a match
    # second request has 3 tokens and no match. Padded with -1 for max len 5
    token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([5, 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([3, 1]))
    assert np.array_equal(result[1], np.array([]))

    # Test non-contiguous indices: requests 0 and 2 need proposals,
    # request 1 is in prefill
    proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    max_model_len = 20
    token_ids_cpu = np.zeros((3, max_model_len), dtype=np.int32)
    token_ids_cpu[0, :5] = [1, 2, 3, 1, 2]
    token_ids_cpu[1, :3] = [4, 5, 6]
    token_ids_cpu[2, :5] = [7, 8, 9, 7, 8]
    num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32)
    sampled_token_ids = [[2], [], [8]]  # Empty list for request 1 simulates prefill
    result = proposer.propose(
        sampled_token_ids=sampled_token_ids,
        req_ids=["0", "1", "2"],
        num_tokens_no_spec=num_tokens_no_spec,
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result) == 3
    assert np.array_equal(result[0], [3, 1])
    assert len(result[1]) == 0
    assert np.array_equal(result[2], [9, 7])
    # Verify internal arrays written to correct indices
    assert proposer.valid_ngram_num_drafts[0] == 2
    assert proposer.valid_ngram_num_drafts[1] == 0
    assert proposer.valid_ngram_num_drafts[2] == 2
    assert np.array_equal(proposer.valid_ngram_draft[0, :2], [3, 1])
    assert np.array_equal(proposer.valid_ngram_draft[2, :2], [9, 7])

    # test if 0 threads available: can happen if TP size > CPU count
    ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    ngram_proposer.num_numba_thread_available = 0
    # set max_model_len to 2 * threshold to ensure multithread is used
    num_tokens_threshold = ngram_proposer.num_tokens_threshold
    ngram_proposer.max_model_len = 2 * num_tokens_threshold
    # using multibatch test
    # input_1 is a long strictly-increasing sequence ending with a repeated
    # 2-gram [middle_integer, middle_integer + 1] so a match is guaranteed.
    middle_integer = num_tokens_threshold // 2
    input_1 = [_ for _ in range(num_tokens_threshold)]
    input_1 += [middle_integer, middle_integer + 1]
    input_2 = [-1] * len(input_1)
    input_2[:3] = [4, 5, 6]
    token_ids_cpu = np.array([input_1, input_2])
    result = ngram_proposer.propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([len(input_1), 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
    assert np.array_equal(result[1], np.array([]))
|
||||
70
tests/v1/spec_decode/test_speculators_eagle3.py
Normal file
70
tests/v1/spec_decode/test_speculators_eagle3.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.model_executor.models.interfaces import supports_eagle3
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_path",
    [
        pytest.param(
            "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
            id="llama3-eagle3-speculator",
        ),
        pytest.param(
            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
            id="qwen3-eagle3-speculator",
        ),
        pytest.param(
            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
            id="qwen3-eagle3-speculator-w4a16-verifier",
            marks=pytest.mark.skipif(
                current_platform.is_rocm(),
                reason="The tests are skipped on rocm platform.",
            ),
        ),
    ],
)
def test_eagle3_speculators_model(
    vllm_runner, example_prompts, model_path, monkeypatch
):
    """
    Test Eagle3 speculators models properly initialize speculative decoding.

    This test verifies:
    1. Eagle3 support is detected for the model
    2. Speculative config is automatically initialized from embedded config
    3. The draft model path is correctly set to the speculators model
    4. Speculative tokens count is valid
    5. Text generation works with speculative decoding enabled
    """
    # Set environment variable for V1 engine serialization
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
        # Verify Eagle3 support is detected
        eagle3_supported = vllm_model.apply_model(supports_eagle3)
        assert eagle3_supported, f"Eagle3 should be supported for {model_path}"

        vllm_config = vllm_model.llm.llm_engine.vllm_config

        # No speculative_config was passed in: it must have been derived
        # from the config embedded in the speculators checkpoint.
        assert isinstance(vllm_config.speculative_config, SpeculativeConfig), (
            "Speculative config should be initialized for speculators model"
        )

        spec_config = vllm_config.speculative_config
        assert spec_config.num_speculative_tokens > 0, (
            f"Expected positive speculative tokens, "
            f"got {spec_config.num_speculative_tokens}"
        )

        # The same checkpoint serves as both verifier and draft source.
        assert spec_config.model == model_path, (
            f"Draft model should be {model_path}, got {spec_config.model}"
        )

        # Smoke check: generation completes with spec decode enabled.
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20)
        assert vllm_outputs, f"No outputs generated for speculators model {model_path}"
|
||||
312
tests/v1/spec_decode/test_tree_attention.py
Normal file
312
tests/v1/spec_decode/test_tree_attention.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
# Tree attention is validated against flash-attn varlen outputs below;
# skip the whole module when that kernel is unavailable on this build.
if not is_flash_attn_varlen_func_available():
    pytest.skip(
        "This test requires flash_attn_varlen_func, but it's not available.",
        allow_module_level=True,
    )
|
||||
|
||||
|
||||
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer object.

    The backend impl only reads the quantization scale tensors off the
    layer it is given; identity (1.0) scales make quantization a no-op.
    """

    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity; present only to satisfy the nn.Module interface.
        return x
|
||||
|
||||
|
||||
def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single forward pass through the given attention backend.

    Args:
        q, k, v: (batch, q_len, heads, head_dim) inputs; every request in
            the batch shares the same query length.
        kv_cache: paged KV cache; cloned before use so the caller's cache
            is not mutated between backend invocations.
        block_table / slot_mapping: paged-KV indexing for the batch.
        seqlen_k: total KV sequence length (context + new tokens).
        backend: which registered attention backend to build and run.
        spec_token_tree / num_spec_tokens: when a tree is given, a
            SpeculativeConfig is attached so tree-attention metadata can
            be built from it.

    Returns:
        The backend's flattened output tensor (tokens, heads, head_dim).
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]
    # Initialize the query and KV sequence lengths.
    # Uniform q_len per request -> cumulative offsets are just multiples.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    num_actual_tokens = query_start_loc[-1]

    softmax_scale = q.shape[-1] ** (-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens.cpu(),
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    # Flatten (batch, q_len) into a single token dimension, as the
    # backend impls expect.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )
|
||||
|
||||
|
||||
def test_tree_attn_correctness() -> None:
    """Cross-check TREE_ATTN against FLASH_ATTN.

    For every root-to-node branch of each speculation tree, the tree
    backend's output rows at the branch indices must match a plain chain
    (causal) attention pass over just that branch's tokens.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    # Map from spec_token_tree string (as passed to SpeculativeConfig) to
    # its expected attention mask: row i attends to the ancestors of node i.
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    # Shuffled physical block ids exercise non-trivial block-table lookups.
    randomize_blocks = True
    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Assert that the number of heads is divisible
                    # by the number of KV heads.
                    assert num_heads % num_kv_heads == 0

                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # Set up the block table and KV cache for paged KV.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    # Random cache contents stand in for the already-computed
                    # context KVs at positions < sequence_position.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )

                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=AttentionBackendEnum.FLASH_ATTN,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        # atol 7.81e-3 == 2**-7, the bf16 rounding scale here.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )
|
||||
|
||||
|
||||
def _gen_slot_mapping(
|
||||
positions: torch.Tensor, block_table: torch.Tensor, block_size: int
|
||||
):
|
||||
block_indices = positions // block_size
|
||||
blocks = block_table.gather(dim=1, index=block_indices)
|
||||
return (blocks * block_size + positions % block_size).view(-1)
|
||||
Reference in New Issue
Block a user