Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -3,7 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast
from typing import Any, cast
import numpy as np
import torch
@@ -20,17 +20,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadata,
@@ -38,14 +34,15 @@ from vllm.v1.attention.backends.tree_attn import (
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID,
compute_new_slot_mapping,
copy_and_expand_eagle_inputs_kernel,
create_vllm_config_for_draft_model,
eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel,
extend_all_queries_by_N,
@@ -53,6 +50,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
@@ -68,6 +66,7 @@ class SpecDecodeBaseProposer:
self.vllm_config = vllm_config
assert vllm_config.speculative_config is not None
self.speculative_config = vllm_config.speculative_config
self.draft_vllm_config = create_vllm_config_for_draft_model(vllm_config)
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model
@@ -79,6 +78,9 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
@@ -113,21 +115,19 @@ class SpecDecodeBaseProposer:
vllm_config.model_config
)
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.draft_attn_groups: list[AttentionGroup] = []
self.kv_cache_gid: int = -1
self.eagle3_use_aux_hidden_state: bool = (
self._get_eagle3_use_aux_hidden_state_from_config()
)
self.compilation_config = self.vllm_config.compilation_config
self.compilation_config = self.draft_vllm_config.compilation_config
# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
# Keys are initialized later via initialize_cudagraph_keys() called from
# gpu_model_runner._check_and_update_cudagraph_mode after
# adjust_cudagraph_sizes_for_spec_decode is called.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.cudagraph_dispatcher = CudagraphDispatcher(self.draft_vllm_config)
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
@@ -353,7 +353,7 @@ class SpecDecodeBaseProposer:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
return {name: view for name in self._draft_attn_layer_names}
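The padded buffer is handed to every draft layer as one shared view via the dict comprehension above; a self-contained sketch of that pattern (layer names and the -1 padding sentinel are illustrative, not taken from this diff):

import torch

PADDING_SLOT_ID = -1                      # illustrative sentinel value
buffer = torch.arange(8)
num_actual, num_tokens = 5, 8
buffer[num_actual:num_tokens] = PADDING_SLOT_ID   # pad the tail slots
view = buffer[:num_tokens]
slot_mappings = {name: view for name in ("draft.attn.0", "draft.attn.1")}
# Both entries alias the same underlying storage; no per-layer copies.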
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle.
@@ -372,6 +372,23 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def adjust_input(
self,
batch_size: int,
target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
return (
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
)
def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Greedy-sample draft tokens from hidden states."""
if self.use_local_argmax_reduction:
@@ -391,6 +408,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
@@ -406,6 +424,21 @@ class SpecDecodeBaseProposer:
)
assert target_hidden_states.shape[-1] == self.hidden_size
(
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
) = self.adjust_input(
batch_size=batch_size,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=target_token_ids,
@@ -420,109 +453,114 @@ class SpecDecodeBaseProposer:
assert self.runner is not None
if self.attn_metadata_builder is None:
attn_metadata_builder = self._get_attention_metadata_builder()
else:
attn_metadata_builder = self.attn_metadata_builder
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
# FIXME: support hybrid kv for draft model (remove separate indexer)
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
)
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
)
num_input_tokens = batch_desc.num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
last_hidden_states, hidden_states = ret_hidden_states
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
sample_hidden_states = last_hidden_states[token_indices_to_sample]
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1 or self.parallel_drafting:
draft_token_ids = self._greedy_sample(sample_hidden_states)
return draft_token_ids.view(-1, self.num_speculative_tokens)
if self.enable_multi_layers_mtp:
model_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(
per_layer_attn_metadata,
self.draft_vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[token_indices_to_sample]
if self.enable_multi_layers_mtp:
logits = self.model.compute_logits(
sample_hidden_states, spec_step_idx=spec_step_idx
)
else:
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1)
# Generate the remaining draft tokens.
draft_token_ids_list.append(draft_token_ids)
if spec_step_idx < self.layer_num - 1:
prev_token_ids = self.input_ids[:num_tokens].clone()
hidden_states = hidden_states[:num_tokens]
next_token_ids = draft_token_ids_list[-1].int()
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=prev_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=hidden_states,
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
)
# Early exit if all draft tokens are generated in one pass
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
if self.uses_mrope:
positions = self.mrope_positions[:, token_indices_to_sample]
else:
positions = self.positions[token_indices_to_sample]
if self.method in (
"deepseek_mtp",
"ernie_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
):
if self.method == "mtp":
hidden_states = self.hidden_states[token_indices_to_sample]
else:
hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states)
if self.enable_multi_layers_mtp:
raise NotImplementedError(
"Speculative Decoding with multi-layer MTP and tree attention "
"is not supported yet."
)
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
batch_size=batch_size,
logits=logits,
@@ -534,32 +572,20 @@ class SpecDecodeBaseProposer:
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = self._greedy_sample(sample_hidden_states)
if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types
):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
"decoding with num_speculative_tokens > layer_num: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}"
)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
self._determine_batch_execution_and_padding(batch_size)
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
@@ -577,7 +603,7 @@ class SpecDecodeBaseProposer:
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1):
for token_index in range(self.num_speculative_tokens - self.layer_num):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
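The comment above is easy to verify; a standalone sketch showing the default argmax dtype and the required cast:

import torch

logits = torch.randn(4, 32000)            # [batch, vocab]
draft_ids = logits.argmax(dim=-1)
assert draft_ids.dtype == torch.int64     # argmax defaults to int64
draft_ids = draft_ids.int()               # explicit cast to int32
assert draft_ids.dtype == torch.int32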
@@ -627,7 +653,8 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping.
block_size = attn_metadata_builder.kv_cache_spec.block_size
# Use the first draft attention group's kv_cache_spec for block_size
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // block_size
@@ -653,11 +680,13 @@ class SpecDecodeBaseProposer:
)
# Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
@@ -683,7 +712,7 @@ class SpecDecodeBaseProposer:
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
self.draft_vllm_config,
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -819,18 +848,17 @@ class SpecDecodeBaseProposer:
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
builder = (
self._get_attention_metadata_builder()
if self.attn_metadata_builder is None
else self.attn_metadata_builder
)
# Use the first draft attention group's kv_cache_spec for block_size
# (all draft layers share the same kv-cache group)
assert len(self.draft_attn_groups) > 0
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens
],
block_size=builder.kv_cache_spec.block_size,
block_size=block_size,
num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len,
)
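For context, the core per-token arithmetic behind a slot mapping; a minimal sketch with a hypothetical block table (the real compute_new_slot_mapping additionally applies the rejection mask, per-request new-slot counts, and max_model_len):

import torch

block_size = 16
block_table = torch.tensor([3, 7, 1])      # logical block -> physical block
positions = torch.tensor([0, 15, 16, 40])  # new token positions

block_numbers = positions // block_size
block_offsets = positions % block_size
slot_mapping = block_table[block_numbers] * block_size + block_offsets
# -> tensor([ 48,  63, 112,  24])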
@@ -880,6 +908,69 @@ class SpecDecodeBaseProposer:
next_token_ids, dtype=torch.int32, device=self.input_ids.device
)
return next_token_ids
def eagle_prepare_next_token_padded(
self,
sampled_token_ids, # [num_reqs, num_sampled_tokens_per_req]
discard_request_mask, # [num_reqs]
backup_next_token_ids, # [num_reqs]
vocab_size,
):
"""
PyTorch implementation of eagle_prepare_next_token_padded kernel.
Args:
sampled_token_ids: Tensor of shape [num_reqs, num_sampled_tokens_per_req]
containing sampled token ids (-1 for rejected tokens)
discard_request_mask: Boolean tensor of shape [num_reqs] indicating
which requests should be discarded
backup_next_token_ids: Tensor of shape [num_reqs] containing backup
token ids for when no valid tokens are found
vocab_size: Vocabulary size for validity checking
Returns:
next_token_ids: Tensor of shape [num_reqs] containing the next token
to sample (last accepted token or backup)
valid_sampled_tokens_count: Tensor of shape [num_reqs] containing the
number of valid (1 + accepted) tokens
"""
num_reqs = sampled_token_ids.shape[0]
num_sampled_tokens_per_req = sampled_token_ids.shape[1]
# Initialize output tensors
next_token_ids = torch.empty(
num_reqs, dtype=sampled_token_ids.dtype, device=sampled_token_ids.device
)
valid_sampled_tokens_count = torch.zeros(
num_reqs, dtype=torch.int32, device=sampled_token_ids.device
)
# Process each request
for req_idx in range(num_reqs):
if discard_request_mask[req_idx]:
# Discarded request: use backup token and valid_count=0
next_token_ids[req_idx] = backup_next_token_ids[req_idx]
valid_sampled_tokens_count[req_idx] = 0
else:
# Get sampled tokens for this request
tokens = sampled_token_ids[req_idx]
# Find valid tokens (not -1 and within vocabulary range)
is_valid = (tokens != -1) & (tokens < vocab_size)
valid_count = is_valid.sum().item()
if valid_count > 0:
# Find the last valid token index
# Get indices where is_valid is True
valid_indices = torch.where(is_valid)[0]
last_valid_idx = valid_indices[-1].item()
# Get the token at that index
last_valid_token = tokens[last_valid_idx]
next_token_ids[req_idx] = last_valid_token
else:
# No valid tokens, use backup token
next_token_ids[req_idx] = backup_next_token_ids[req_idx]
valid_sampled_tokens_count[req_idx] = valid_count
return next_token_ids, valid_sampled_tokens_count
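A vectorized rendering of the same selection rule, runnable standalone with toy values (it mirrors the per-request loop above rather than calling it):

import torch

sampled = torch.tensor([[5, 9, -1], [7, -1, -1]])  # -1 marks rejected tokens
discard = torch.tensor([False, True])
backup = torch.tensor([11, 13])
vocab_size = 32000

is_valid = (sampled != -1) & (sampled < vocab_size)
counts = is_valid.sum(dim=1).int()
# Index of the last valid token per request (0 when none are valid).
last_idx = (is_valid.int().cumsum(dim=1) * is_valid.int()).argmax(dim=1)
next_ids = sampled.gather(1, last_idx.unsqueeze(1)).squeeze(1)
use_backup = discard | (counts == 0)
next_ids = torch.where(use_backup, backup, next_ids)
counts = torch.where(discard, torch.zeros_like(counts), counts)
# next_ids -> tensor([ 9, 13]); counts -> tensor([2, 0], dtype=torch.int32)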
def prepare_next_token_ids_padded(
self,
@@ -910,31 +1001,15 @@ class SpecDecodeBaseProposer:
self.backup_next_token_ids.copy_to_gpu(num_reqs)
backup_tokens_gpu = self.backup_next_token_ids.gpu
batch_size, num_tokens = sampled_token_ids.shape
device = sampled_token_ids.device
assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32
next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
# Kernel grid: one program per request (row)
grid = (batch_size,)
# Find the next power of 2 for block sizes
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
eagle_prepare_next_token_padded_kernel[grid](
next_token_ids, valid_sampled_tokens_count = self.eagle_prepare_next_token_padded(
sampled_token_ids,
discard_request_mask,
backup_tokens_gpu,
next_token_ids,
valid_sampled_tokens_count,
gpu_input_batch.vocab_size,
num_tokens,
batch_size,
sampled_token_ids.stride(0),
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)
return next_token_ids, valid_sampled_tokens_count
@@ -974,6 +1049,8 @@ class SpecDecodeBaseProposer:
)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
key_start_loc = common_attn_metadata.key_start_loc
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
total_num_tokens = query_start_loc_cpu[-1].item()
@@ -981,7 +1058,9 @@ class SpecDecodeBaseProposer:
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
seq_lens=common_attn_metadata.seq_lens,
seq_lens_np=common_attn_metadata.seq_lens_np,
query_start_loc_cpu=query_start_loc_cpu,
key_start_loc=key_start_loc,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
@@ -1014,9 +1093,7 @@ class SpecDecodeBaseProposer:
| list[dict[str, torch.Tensor]]
| None = None,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_groups[0][
0
].get_metadata_builder()
tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
total_num_drafts = self.cu_drafts_per_level[0]
@@ -1092,10 +1169,11 @@ class SpecDecodeBaseProposer:
common_attn_metadata=common_attn_metadata, draft_index=level + 1
)
# Apply new attention metadata to all layers.
# Apply new attention metadata to all draft layers.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for attn_group in self.draft_attn_groups:
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# Consider max model length.
attn_metadata.max_seq_len = min(
@@ -1131,7 +1209,7 @@ class SpecDecodeBaseProposer:
# Run the model.
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
self.draft_vllm_config,
num_tokens=num_input_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
@@ -1209,6 +1287,7 @@ class SpecDecodeBaseProposer:
device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
key_start_loc = common_attn_metadata.key_start_loc
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
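The comment above is the usual cumulative-offset inversion; a one-line check:

import torch

query_start_loc = torch.tensor([0, 3, 5, 9])        # [0, q1, q1+q2, q1+q2+q3]
query_lens = query_start_loc[1:] - query_start_loc[:-1]
# -> tensor([3, 2, 4]), i.e. [q1, q2, q3]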
@@ -1261,6 +1340,7 @@ class SpecDecodeBaseProposer:
query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
key_start_loc=key_start_loc,
_seq_lens_cpu=new_seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
@@ -1289,7 +1369,7 @@ class SpecDecodeBaseProposer:
with set_model_tag("eagle_head"):
model = get_model(
vllm_config=self.vllm_config,
vllm_config=self.draft_vllm_config,
model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
)
@@ -1302,43 +1382,17 @@ class SpecDecodeBaseProposer:
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
).keys()
)
self.model = self._get_model()
draft_attn_layer_names = (
get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
- target_attn_layer_names
# Find draft layers (attention layers added by draft model)
all_attn_layers = get_layers_from_vllm_config(
self.draft_vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
self._draft_attn_layer_names = (
set(all_attn_layers.keys()) - target_attn_layer_names
)
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names:
first_layer = self.indexer_layer_names[0]
self.draft_indexer_metadata_builder = (
indexer_layers[first_layer]
.get_attn_backend()
.get_builder_cls()(
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
self.indexer_layer_names,
self.vllm_config,
self.device,
)
)
else:
self.draft_indexer_metadata_builder = None
if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use
@@ -1568,25 +1622,17 @@ class SpecDecodeBaseProposer:
self.num_speculative_tokens if not is_graph_capturing else 1
):
if fwd_idx <= 1:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
if use_cudagraphs:
cudagraph_runtime_mode, batch_desc = (
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
num_input_tokens = batch_desc.num_tokens
else:
cudagraph_runtime_mode = CUDAGraphMode.NONE
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
)
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self.attn_layer_names
self._draft_attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
and next(iter(self._draft_attn_layer_names)) in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
@@ -1594,7 +1640,7 @@ class SpecDecodeBaseProposer:
with set_forward_context(
None,
self.vllm_config,
self.draft_vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -1616,31 +1662,6 @@ class SpecDecodeBaseProposer:
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
self.model(**kwargs)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
Returns:
The metadata builders for EAGLE layers.
Raises:
AssertionError: If no metadata builders are found for EAGLE layers.
"""
builder = None
chosen_layer = self.attn_layer_names[0]
for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
if chosen_layer in attn_group.layer_names:
builder = attn_group.get_metadata_builder()
break
if builder is not None:
break
assert builder is not None, (
"Failed to find attention metadata builder for EAGLE layers."
)
return builder
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
"""
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
@@ -1673,35 +1694,114 @@ class SpecDecodeBaseProposer:
set(
[
kv_cache_groups[layer_name]
for layer_name in self.attn_layer_names
for layer_name in self._draft_attn_layer_names
]
)
)
== 1
), "All drafting layers should belong to the same kv cache group"
def _pad_batch_across_dp(
def initialize_attn_backend(
self,
num_tokens_unpadded: int,
num_tokens_padded: int,
) -> tuple[int, torch.Tensor]:
# TODO(Flechman): support DBO ubatching
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
!= CUDAGraphMode.NONE,
num_tokens_padded=num_tokens_padded,
uniform_decode=None,
num_scheduled_tokens_per_request=None,
kv_cache_config: KVCacheConfig,
kernel_block_sizes: list[int] | None = None,
) -> None:
"""
Initialize AttentionGroups for draft layers using kv_cache_config.
Called from the model runner's initialize_metadata_builders.
"""
all_attn_layers = get_layers_from_vllm_config(
self.draft_vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
num_tokens_dp_padded = num_tokens_padded
if num_toks_across_dp is not None:
num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
return num_tokens_dp_padded, num_toks_across_dp
# Find which kv_cache_group the draft layers belong to
self.validate_same_kv_cache_group(kv_cache_config)
kv_cache_spec = None
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
if self._draft_attn_layer_names & set(group.layer_names):
self.kv_cache_gid = gid
kv_cache_spec = group.kv_cache_spec
break
attention_groups: dict[tuple[str, str], AttentionGroup] = {}
if kv_cache_spec is not None:
for layer_name in self._draft_attn_layer_names:
attn_backend = all_attn_layers[layer_name].get_attn_backend()
backend_key = attn_backend.full_cls_name()
if backend_key not in attention_groups:
layer_kv_cache_spec = kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
layer_name
]
kernel_block_size = (
kernel_block_sizes[self.kv_cache_gid]
if kernel_block_sizes is not None
and self.kv_cache_gid < len(kernel_block_sizes)
else None
)
attn_group = AttentionGroup(
backend=attn_backend,
layer_names=[layer_name],
kv_cache_spec=layer_kv_cache_spec,
kv_cache_group_id=self.kv_cache_gid,
)
attn_group.create_metadata_builders(
self.draft_vllm_config,
self.device,
kernel_block_size=kernel_block_size,
)
attention_groups[backend_key] = attn_group
else:
attention_groups[backend_key].layer_names.append(layer_name)
self.draft_attn_groups = list(attention_groups.values())
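The loop above buckets draft layers by their attention backend's fully qualified class name; a simplified sketch of the same bucketing using plain dicts and hypothetical names in place of AttentionGroup:

layer_backends = {
    "draft.layers.0.attn": "vllm.v1.attention.FlashAttentionBackend",
    "draft.layers.1.attn": "vllm.v1.attention.FlashAttentionBackend",
}

groups: dict[str, list[str]] = {}
for layer_name, backend_key in layer_backends.items():
    groups.setdefault(backend_key, []).append(layer_name)
# -> a single group holding both layers, analogous to draft_attn_groups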
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination is needed when running data-parallel, since the
# padded token count must be agreed upon across ranks.
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.draft_vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.draft_vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert that the agreed-upon token count is correct; otherwise
# num_tokens_across_dp would no longer be valid.
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
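To make the padding step concrete: each rank rounds its token count up to a capturable batch size, after which coordinate_batch_across_dp reconciles the per-rank values and the dispatcher re-validates the size. A toy sketch of the rounding only (capture sizes are hypothetical; no distributed state is involved):

import torch

capture_sizes = [8, 16, 32]                    # hypothetical cudagraph sizes

def pad_to_capture_size(n: int) -> int:
    return next((s for s in capture_sizes if s >= n), n)

num_tokens_across_dp = torch.tensor([5, 12])   # unpadded tokens per DP rank
padded = torch.tensor(
    [pad_to_capture_size(int(n)) for n in num_tokens_across_dp]
)
# -> tensor([ 8, 16]); each rank's entry is bumped to a capturable size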
class EagleProposer(SpecDecodeBaseProposer):