Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -3,7 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast
from typing import Any, cast

import numpy as np
import torch
@@ -20,17 +20,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
@@ -38,14 +34,15 @@ from vllm.v1.attention.backends.tree_attn import (
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
    create_vllm_config_for_draft_model,
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
    extend_all_queries_by_N,
@@ -53,6 +50,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup

logger = init_logger(__name__)
|
||||
|
||||
@@ -68,6 +66,7 @@ class SpecDecodeBaseProposer:
|
||||
self.vllm_config = vllm_config
|
||||
assert vllm_config.speculative_config is not None
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.draft_vllm_config = create_vllm_config_for_draft_model(vllm_config)
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
self.pass_hidden_states_to_model = pass_hidden_states_to_model
|
||||
@@ -79,6 +78,9 @@ class SpecDecodeBaseProposer:
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||
|
||||
self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
|
||||
self.layer_num = 1
|
||||
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
@@ -113,21 +115,19 @@ class SpecDecodeBaseProposer:
|
||||
vllm_config.model_config
|
||||
)
|
||||
|
||||
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
self.draft_attn_groups: list[AttentionGroup] = []
|
||||
self.kv_cache_gid: int = -1
|
||||
self.eagle3_use_aux_hidden_state: bool = (
|
||||
self._get_eagle3_use_aux_hidden_state_from_config()
|
||||
)
|
||||
|
||||
self.compilation_config = self.vllm_config.compilation_config
|
||||
self.compilation_config = self.draft_vllm_config.compilation_config
|
||||
|
||||
# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
|
||||
# Keys are initialized later via initialize_cudagraph_keys() called from
|
||||
# gpu_model_runner._check_and_update_cudagraph_mode after
|
||||
# adjust_cudagraph_sizes_for_spec_decode is called.
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.draft_vllm_config)
|
||||
|
||||
# persistent buffers for cuda graph
|
||||
self.input_ids = torch.zeros(
|
||||
@@ -353,7 +353,7 @@ class SpecDecodeBaseProposer:
|
||||
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
|
||||
|
||||
view = self._slot_mapping_buffer[:num_tokens]
|
||||
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
|
||||
return {name: view for name in self._draft_attn_layer_names}
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
"""Initialize cudagraph dispatcher keys for eagle.
|
||||
@@ -372,6 +372,23 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
|
||||
|
||||
def adjust_input(
|
||||
self,
|
||||
batch_size: int,
|
||||
target_token_ids: torch.Tensor,
|
||||
target_positions: torch.Tensor,
|
||||
target_hidden_states: torch.Tensor,
|
||||
token_indices_to_sample: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
|
||||
return (
|
||||
target_token_ids,
|
||||
target_positions,
|
||||
target_hidden_states,
|
||||
common_attn_metadata,
|
||||
)
|
||||
|
||||
def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Greedy-sample draft tokens from hidden states."""
|
||||
if self.use_local_argmax_reduction:
|
||||
@@ -391,6 +408,7 @@ class SpecDecodeBaseProposer:
|
||||
token_indices_to_sample: torch.Tensor | None,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
@@ -406,6 +424,21 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
(
|
||||
target_token_ids,
|
||||
target_positions,
|
||||
target_hidden_states,
|
||||
common_attn_metadata,
|
||||
) = self.adjust_input(
|
||||
batch_size=batch_size,
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
token_indices_to_sample=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
|
||||
)
|
||||
|
||||
num_tokens, token_indices_to_sample, common_attn_metadata = (
|
||||
self.set_inputs_first_pass(
|
||||
target_token_ids=target_token_ids,
|
||||
@@ -420,109 +453,114 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
if self.attn_metadata_builder is None:
|
||||
attn_metadata_builder = self._get_attention_metadata_builder()
|
||||
else:
|
||||
attn_metadata_builder = self.attn_metadata_builder
|
||||
per_layer_attn_metadata: dict[str, object] = {}
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(num_tokens)
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model (remove separate indexer)
|
||||
if self.draft_indexer_metadata_builder:
|
||||
draft_indexer_metadata = (
|
||||
self.draft_indexer_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
|
||||
draft_token_ids_list = []
|
||||
for spec_step_idx in range(self.layer_num):
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
|
||||
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
|
||||
self.input_ids[:num_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
)
|
||||
else:
|
||||
draft_indexer_metadata = None
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
for layer_name in self.indexer_layer_names:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens_dp_padded
|
||||
)
|
||||
num_input_tokens = batch_desc.num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
|
||||
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
|
||||
self.input_ids[:num_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1 or self.parallel_drafting:
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
return draft_token_ids.view(-1, self.num_speculative_tokens)
|
||||
if self.enable_multi_layers_mtp:
|
||||
model_kwargs["spec_step_idx"] = spec_step_idx
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
if self.enable_multi_layers_mtp:
|
||||
logits = self.model.compute_logits(
|
||||
sample_hidden_states, spec_step_idx=spec_step_idx
|
||||
)
|
||||
else:
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
if spec_step_idx < self.layer_num - 1:
|
||||
prev_token_ids = self.input_ids[:num_tokens].clone()
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
next_token_ids = draft_token_ids_list[-1].int()
|
||||
|
||||
num_tokens, token_indices_to_sample, common_attn_metadata = (
|
||||
self.set_inputs_first_pass(
|
||||
target_token_ids=prev_token_ids,
|
||||
next_token_ids=next_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=hidden_states,
|
||||
token_indices_to_sample=token_indices_to_sample,
|
||||
cad=common_attn_metadata,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
)
|
||||
)
|
||||
|
||||
# Early exit if all draft tokens are generated in one pass
|
||||
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, token_indices_to_sample]
|
||||
else:
|
||||
positions = self.positions[token_indices_to_sample]
|
||||
if self.method in (
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
):
|
||||
if self.method == "mtp":
|
||||
hidden_states = self.hidden_states[token_indices_to_sample]
|
||||
else:
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
# Draft using tree attention - requires full logits for top-k
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if self.enable_multi_layers_mtp:
|
||||
raise NotImplementedError(
|
||||
"Speculative Decoding with multi-layer MTP and tree attention "
|
||||
"is not supported yet."
|
||||
)
|
||||
# Draft using tree attention.
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
@@ -534,32 +572,20 @@ class SpecDecodeBaseProposer:
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
|
||||
if self.allowed_attn_types is not None and not isinstance(
|
||||
attn_metadata, self.allowed_attn_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
"decoding with num_speculative_tokens > layer_num: "
|
||||
f"{type(attn_metadata)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}"
|
||||
)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
|
||||
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
|
||||
self._determine_batch_execution_and_padding(batch_size)
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
batch_size_dp_padded
|
||||
)
|
||||
input_batch_size = batch_desc.num_tokens
|
||||
if batch_size_across_dp is not None:
|
||||
batch_size_across_dp[self.dp_rank] = input_batch_size
|
||||
|
||||
common_attn_metadata.num_actual_tokens = batch_size
|
||||
common_attn_metadata.max_query_len = 1
|
||||
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
|
||||
@@ -577,7 +603,7 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._seq_lens_cpu = None
|
||||
common_attn_metadata._num_computed_tokens_cpu = None
|
||||
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
for token_index in range(self.num_speculative_tokens - self.layer_num):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@@ -627,7 +653,8 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_size = attn_metadata_builder.kv_cache_spec.block_size
|
||||
# Use the first draft attention group's kv_cache_spec for block_size
|
||||
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||
if self.uses_mrope:
|
||||
# all dimensions of positions are the same
|
||||
block_numbers = clamped_positions[0] // block_size
|
||||
@@ -653,11 +680,13 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
|
||||
)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
@@ -683,7 +712,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=input_batch_size,
|
||||
num_tokens_across_dp=batch_size_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
@@ -819,18 +848,17 @@ class SpecDecodeBaseProposer:
|
||||
# 2.
|
||||
# Recompute the slot mapping based on the new positions and
|
||||
# rejection mask.
|
||||
builder = (
|
||||
self._get_attention_metadata_builder()
|
||||
if self.attn_metadata_builder is None
|
||||
else self.attn_metadata_builder
|
||||
)
|
||||
# Use the first draft attention group's kv_cache_spec for block_size
|
||||
# (all draft layers share the same kv-cache group)
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||
new_slot_mapping = compute_new_slot_mapping(
|
||||
cad=cad,
|
||||
new_positions=self.positions[:total_num_output_tokens],
|
||||
is_rejected_token_mask=self.is_rejected_token_mask[
|
||||
:total_num_output_tokens
|
||||
],
|
||||
block_size=builder.kv_cache_spec.block_size,
|
||||
block_size=block_size,
|
||||
num_new_tokens=self.net_num_new_slots_per_request,
|
||||
max_model_len=self.max_model_len,
|
||||
)
|
||||
@@ -880,6 +908,69 @@ class SpecDecodeBaseProposer:
|
||||
next_token_ids, dtype=torch.int32, device=self.input_ids.device
|
||||
)
|
||||
return next_token_ids
|
||||
|
||||
    def eagle_prepare_next_token_padded(
        self,
        sampled_token_ids,  # [num_reqs, num_sampled_tokens_per_req]
        discard_request_mask,  # [num_reqs]
        backup_next_token_ids,  # [num_reqs]
        vocab_size,
    ):
        """
        PyTorch implementation of the eagle_prepare_next_token_padded kernel.

        Args:
            sampled_token_ids: Tensor of shape [num_reqs, num_sampled_tokens_per_req]
                containing sampled token ids (-1 for rejected tokens)
            discard_request_mask: Boolean tensor of shape [num_reqs] indicating
                which requests should be discarded
            backup_next_token_ids: Tensor of shape [num_reqs] containing backup
                token ids for when no valid tokens are found
            vocab_size: Vocabulary size for validity checking

        Returns:
            next_token_ids: Tensor of shape [num_reqs] containing the next token
                to sample (last accepted token or backup)
            valid_sampled_tokens_count: Tensor of shape [num_reqs] containing the
                number of valid (1 + accepted) tokens
        """
        num_reqs = sampled_token_ids.shape[0]
        num_sampled_tokens_per_req = sampled_token_ids.shape[1]

        # Initialize output tensors
        next_token_ids = torch.empty(
            num_reqs, dtype=sampled_token_ids.dtype, device=sampled_token_ids.device
        )
        valid_sampled_tokens_count = torch.zeros(
            num_reqs, dtype=torch.int32, device=sampled_token_ids.device
        )

        # Process each request
        for req_idx in range(num_reqs):
            if discard_request_mask[req_idx]:
                # Discarded request: use backup token and valid_count=0
                next_token_ids[req_idx] = backup_next_token_ids[req_idx]
                valid_sampled_tokens_count[req_idx] = 0
            else:
                # Get sampled tokens for this request
                tokens = sampled_token_ids[req_idx]

                # Find valid tokens (not -1 and within vocabulary range)
                is_valid = (tokens != -1) & (tokens < vocab_size)
                valid_count = is_valid.sum().item()

                if valid_count > 0:
                    # Find the index of the last valid token
                    valid_indices = torch.where(is_valid)[0]
                    last_valid_idx = valid_indices[-1].item()

                    # Get the token at that index
                    last_valid_token = tokens[last_valid_idx]
                    next_token_ids[req_idx] = last_valid_token
                else:
                    # No valid tokens, use backup token
                    next_token_ids[req_idx] = backup_next_token_ids[req_idx]

                valid_sampled_tokens_count[req_idx] = valid_count

        return next_token_ids, valid_sampled_tokens_count
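A minimal standalone sketch (not from this diff) of the same last-accepted-token selection, written with batched torch ops on a toy input; the tensor values and vocab size below are made up:

import torch

sampled = torch.tensor([[5, 9, -1], [3, -1, -1]])   # -1 marks rejected drafts
discard = torch.tensor([False, True])                # second request is discarded
backup = torch.tensor([7, 8])                        # fallback token per request
vocab_size = 32_000

is_valid = (sampled != -1) & (sampled < vocab_size)
counts = is_valid.sum(dim=1)                         # accepted tokens per request
last_idx = is_valid.cumsum(dim=1).argmax(dim=1)      # column of the last valid token
picked = sampled.gather(1, last_idx.unsqueeze(1)).squeeze(1)
use_backup = discard | (counts == 0)
next_tokens = torch.where(use_backup, backup, picked)
counts = torch.where(discard, torch.zeros_like(counts), counts)
# next_tokens -> tensor([9, 8]); counts -> tensor([2, 0])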
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
@@ -910,31 +1001,15 @@ class SpecDecodeBaseProposer:
|
||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||
backup_tokens_gpu = self.backup_next_token_ids.gpu
|
||||
|
||||
batch_size, num_tokens = sampled_token_ids.shape
|
||||
device = sampled_token_ids.device
|
||||
|
||||
assert discard_request_mask.dtype == torch.bool
|
||||
assert backup_tokens_gpu.dtype == torch.int32
|
||||
|
||||
next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
|
||||
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
|
||||
|
||||
# Kernel grid: one program per request (row)
|
||||
grid = (batch_size,)
|
||||
|
||||
# Find the next power of 2 for block sizes
|
||||
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
|
||||
eagle_prepare_next_token_padded_kernel[grid](
|
||||
|
||||
next_token_ids, valid_sampled_tokens_count = self.eagle_prepare_next_token_padded(
|
||||
sampled_token_ids,
|
||||
discard_request_mask,
|
||||
backup_tokens_gpu,
|
||||
next_token_ids,
|
||||
valid_sampled_tokens_count,
|
||||
gpu_input_batch.vocab_size,
|
||||
num_tokens,
|
||||
batch_size,
|
||||
sampled_token_ids.stride(0),
|
||||
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
|
||||
)
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
@@ -974,6 +1049,8 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
key_start_loc = common_attn_metadata.key_start_loc
|
||||
|
||||
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
@@ -981,7 +1058,9 @@ class SpecDecodeBaseProposer:
|
||||
spec_common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
seq_lens_np=common_attn_metadata.seq_lens_np,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
key_start_loc=key_start_loc,
|
||||
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
@@ -1014,9 +1093,7 @@ class SpecDecodeBaseProposer:
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
tree_attn_metadata_builder = self.runner.attn_groups[0][
|
||||
0
|
||||
].get_metadata_builder()
|
||||
tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
|
||||
|
||||
total_num_drafts = self.cu_drafts_per_level[0]
|
||||
@@ -1092,10 +1169,11 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata=common_attn_metadata, draft_index=level + 1
|
||||
)
|
||||
|
||||
# Apply new attention metadata to all layers.
|
||||
# Apply new attention metadata to all draft layers.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# Consider max model length.
|
||||
attn_metadata.max_seq_len = min(
|
||||
@@ -1131,7 +1209,7 @@ class SpecDecodeBaseProposer:
|
||||
# Run the model.
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
@@ -1209,6 +1287,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
key_start_loc = common_attn_metadata.key_start_loc
|
||||
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
|
||||
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
||||
@@ -1261,6 +1340,7 @@ class SpecDecodeBaseProposer:
|
||||
query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
|
||||
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
|
||||
query_start_loc_cpu=new_query_start_loc_cpu,
|
||||
key_start_loc=key_start_loc,
|
||||
_seq_lens_cpu=new_seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
@@ -1289,7 +1369,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_model_tag("eagle_head"):
|
||||
model = get_model(
|
||||
vllm_config=self.vllm_config,
|
||||
vllm_config=self.draft_vllm_config,
|
||||
model_config=self.speculative_config.draft_model_config,
|
||||
load_config=self.speculative_config.draft_load_config,
|
||||
)
|
||||
@@ -1302,43 +1382,17 @@ class SpecDecodeBaseProposer:
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model
|
||||
target_indexer_layer_names = set(
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
).keys()
|
||||
)
|
||||
|
||||
self.model = self._get_model()
|
||||
|
||||
draft_attn_layer_names = (
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
- target_attn_layer_names
|
||||
# Find draft layers (attention layers added by draft model)
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.draft_vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
indexer_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
self._draft_attn_layer_names = (
|
||||
set(all_attn_layers.keys()) - target_attn_layer_names
|
||||
)
|
||||
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
|
||||
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
|
||||
self.indexer_layer_names = list(draft_indexer_layer_names)
|
||||
|
||||
if self.indexer_layer_names:
|
||||
first_layer = self.indexer_layer_names[0]
|
||||
self.draft_indexer_metadata_builder = (
|
||||
indexer_layers[first_layer]
|
||||
.get_attn_backend()
|
||||
.get_builder_cls()(
|
||||
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
|
||||
self.indexer_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.draft_indexer_metadata_builder = None
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
# Even if the target model is multimodal, we can also use
|
||||
@@ -1568,25 +1622,17 @@ class SpecDecodeBaseProposer:
|
||||
self.num_speculative_tokens if not is_graph_capturing else 1
|
||||
):
|
||||
if fwd_idx <= 1:
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
|
||||
)
|
||||
if use_cudagraphs:
|
||||
cudagraph_runtime_mode, batch_desc = (
|
||||
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens, use_cudagraphs=use_cudagraphs
|
||||
)
|
||||
num_input_tokens = batch_desc.num_tokens
|
||||
else:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
num_input_tokens = num_tokens_dp_padded
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
)
|
||||
|
||||
# Make sure to use EAGLE's own buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
self._draft_attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
and next(iter(self._draft_attn_layer_names)) in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
@@ -1594,7 +1640,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
@@ -1616,31 +1662,6 @@ class SpecDecodeBaseProposer:
|
||||
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
self.model(**kwargs)
|
||||
|
||||
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
|
||||
Returns:
|
||||
The metadata builders for EAGLE layers.
|
||||
|
||||
Raises:
|
||||
AssertionError: If no metadata builders are found for EAGLE layers.
|
||||
"""
|
||||
builder = None
|
||||
chosen_layer = self.attn_layer_names[0]
|
||||
|
||||
for kv_cache_group in self.runner.attn_groups:
|
||||
for attn_group in kv_cache_group:
|
||||
if chosen_layer in attn_group.layer_names:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
break
|
||||
if builder is not None:
|
||||
break
|
||||
|
||||
assert builder is not None, (
|
||||
"Failed to find attention metadata builder for EAGLE layers."
|
||||
)
|
||||
return builder
|
||||
|
||||
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
|
||||
"""
|
||||
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
|
||||
@@ -1673,35 +1694,114 @@ class SpecDecodeBaseProposer:
|
||||
set(
|
||||
[
|
||||
kv_cache_groups[layer_name]
|
||||
for layer_name in self.attn_layer_names
|
||||
for layer_name in self._draft_attn_layer_names
|
||||
]
|
||||
)
|
||||
)
|
||||
== 1
|
||||
), "All drafting layers should belong to the same kv cache group"
|
||||
|
||||
def _pad_batch_across_dp(
|
||||
def initialize_attn_backend(
|
||||
self,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
) -> tuple[int, torch.Tensor]:
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
|
||||
!= CUDAGraphMode.NONE,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
uniform_decode=None,
|
||||
num_scheduled_tokens_per_request=None,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kernel_block_sizes: list[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AttentionGroups for draft layers using kv_cache_config.
|
||||
Called from the model runner's initialize_metadata_builders.
|
||||
"""
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.draft_vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
|
||||
|
||||
num_tokens_dp_padded = num_tokens_padded
|
||||
if num_toks_across_dp is not None:
|
||||
num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
|
||||
return num_tokens_dp_padded, num_toks_across_dp
|
||||
# Find which kv_cache_group the draft layers belong to
|
||||
self.validate_same_kv_cache_group(kv_cache_config)
|
||||
kv_cache_spec = None
|
||||
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
if self._draft_attn_layer_names & set(group.layer_names):
|
||||
self.kv_cache_gid = gid
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
break
|
||||
|
||||
attention_groups: dict[tuple[str, str], AttentionGroup] = {}
|
||||
if kv_cache_spec is not None:
|
||||
for layer_name in self._draft_attn_layer_names:
|
||||
attn_backend = all_attn_layers[layer_name].get_attn_backend()
|
||||
backend_key = attn_backend.full_cls_name()
|
||||
if backend_key not in attention_groups:
|
||||
layer_kv_cache_spec = kv_cache_spec
|
||||
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
|
||||
layer_name
|
||||
]
|
||||
|
||||
kernel_block_size = (
|
||||
kernel_block_sizes[self.kv_cache_gid]
|
||||
if kernel_block_sizes is not None
|
||||
and self.kv_cache_gid < len(kernel_block_sizes)
|
||||
else None
|
||||
)
|
||||
attn_group = AttentionGroup(
|
||||
backend=attn_backend,
|
||||
layer_names=[layer_name],
|
||||
kv_cache_spec=layer_kv_cache_spec,
|
||||
kv_cache_group_id=self.kv_cache_gid,
|
||||
)
|
||||
attn_group.create_metadata_builders(
|
||||
self.draft_vllm_config,
|
||||
self.device,
|
||||
kernel_block_size=kernel_block_size,
|
||||
)
|
||||
attention_groups[backend_key] = attn_group
|
||||
else:
|
||||
attention_groups[backend_key].layer_names.append(layer_name)
|
||||
|
||||
self.draft_attn_groups = list(attention_groups.values())
|
||||
|
||||
def _determine_batch_execution_and_padding(
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens,
|
||||
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
|
||||
)
|
||||
num_tokens_padded = batch_desc.num_tokens
|
||||
|
||||
# Extra coordination when running data-parallel since we need to
|
||||
# coordinate across ranks
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
should_ubatch, num_tokens_across_dp = False, None
|
||||
if self.draft_vllm_config.parallel_config.data_parallel_size > 1:
|
||||
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
parallel_config=self.draft_vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
cudagraph_mode=cudagraph_mode.value,
|
||||
)
|
||||
)
|
||||
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
|
||||
|
||||
# Extract DP-synced values
|
||||
if num_tokens_across_dp is not None:
|
||||
dp_rank = self.dp_rank
|
||||
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
|
||||
# Re-dispatch with DP padding so we have the correct
|
||||
# batch_descriptor
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens_padded,
|
||||
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
|
||||
)
|
||||
# Assert to make sure the agreed upon token count is correct
|
||||
# otherwise num_tokens_across_dp will no-longer be valid
|
||||
assert batch_desc.num_tokens == num_tokens_padded
|
||||
num_tokens_across_dp[dp_rank] = num_tokens_padded
|
||||
|
||||
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
|
||||
|
||||
|
||||
class EagleProposer(SpecDecodeBaseProposer):
|
||||
|
||||
vllm/v1/spec_decode/extract_hidden_states.py (new file, 395 lines)
@@ -0,0 +1,395 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer import has_kv_transfer_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
class ExtractHiddenStatesProposer:
|
||||
def __init__(self, vllm_config: VllmConfig, device):
|
||||
assert vllm_config.speculative_config is not None
|
||||
|
||||
assert vllm_config.speculative_config.num_speculative_tokens == 1
|
||||
if vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||
raise ValueError(
|
||||
"disable_padded_drafter_batch is not supported with "
|
||||
"extract_hidden_states method"
|
||||
)
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Model and attention layer tracking (initialized in load_model)
|
||||
self.model: nn.Module | None = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
|
||||
# Maximum number of tokens for buffers
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
|
||||
)
|
||||
|
||||
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
|
||||
if not layer_ids:
|
||||
raise ValueError(
|
||||
"eagle_aux_hidden_state_layer_ids must be set in the draft "
|
||||
"model config for extract_hidden_states method"
|
||||
)
|
||||
self.num_hidden_states = len(layer_ids)
|
||||
self.hidden_size = vllm_config.model_config.get_hidden_size()
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_tokens, self.num_hidden_states, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
|
||||
self._slot_mapping_buffer = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
target_hidden_states: list[torch.Tensor],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
scheduler_output: SchedulerOutput,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
|
||||
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
|
||||
|
||||
The ExtractHiddenStatesModel caches the hidden states in the KV cache
|
||||
without performing actual attention computation. This allows us to
|
||||
extract and store hidden states for later use (e.g., KV transfer).
|
||||
|
||||
This proposer doesn't actually perform speculation - it returns the
|
||||
sampled tokens as "draft" tokens, ensuring they always verify (match).
|
||||
The main purpose is to cache hidden states, not to speculate.
|
||||
|
||||
Args:
|
||||
sampled_token_ids: Sampled token IDs from the target model
|
||||
target_hidden_states: List of hidden state tensors from target model
|
||||
(one per aux hidden state layer)
|
||||
common_attn_metadata: Attention metadata
|
||||
scheduler_output: Scheduler output for KV connector
|
||||
slot_mappings: Slot mappings for KV cache (unused, provided for
|
||||
interface compatibility)
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Draft tokens matching sampled tokens, shape [batch_size, 1]
|
||||
- KV connector output (if KV transfer is active), else None
|
||||
"""
|
||||
assert self.model is not None and isinstance(target_hidden_states, list)
|
||||
|
||||
# target_hidden_states is a list of tensors (one per layer)
|
||||
# Each tensor has shape [num_tokens, hidden_size]
|
||||
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
|
||||
stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
|
||||
num_tokens = stacked_hidden_states.shape[0]
|
||||
|
||||
# Copy hidden states to buffer
|
||||
self.hidden_states[:num_tokens] = stacked_hidden_states
|
||||
|
||||
assert self.attn_metadata_builder is not None
|
||||
attn_metadata = self.attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
|
||||
# We assume all cache-only layers belong to the same KV cache group,
|
||||
# thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(num_tokens)
|
||||
)
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
with (
|
||||
set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
),
|
||||
(
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
|
||||
if has_kv_transfer_group()
|
||||
else nullcontext()
|
||||
) as kv_connector_output,
|
||||
):
|
||||
self.model(
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
|
||||
# Return the sampled tokens as "draft" tokens
|
||||
# Shape: [batch_size, 1] to match num_speculative_tokens=1
|
||||
return sampled_token_ids.unsqueeze(-1), kv_connector_output
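As a sanity check on the statement above that these "drafts" always verify, a tiny illustrative snippet (not part of the new file; token values are made up):

import torch

sampled_token_ids = torch.tensor([11, 42, 7])      # target-model samples, [batch_size]
draft_token_ids = sampled_token_ids.unsqueeze(-1)   # [batch_size, 1], num_spec_tokens=1
# The single draft per request is exactly the token the target already sampled,
# so verification trivially accepts it.
assert torch.equal(draft_token_ids[:, 0], sampled_token_ids)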
|
||||
|
||||
def _get_slot_mapping(
|
||||
self,
|
||||
num_tokens: int,
|
||||
slot_mapping: torch.Tensor | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Return slot_mapping dict for cache-only attention layers.
|
||||
|
||||
If slot_mapping is provided, copies it into the buffer first.
|
||||
"""
|
||||
if slot_mapping is not None:
|
||||
num_actual = slot_mapping.shape[0]
|
||||
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
|
||||
if num_tokens > num_actual:
|
||||
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
|
||||
|
||||
view = self._slot_mapping_buffer[:num_tokens]
|
||||
return {name: view for name in self.attn_layer_names}
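A small illustrative example (not part of the new file) of the padding behaviour described above, with toy sizes:

import torch

PADDING_SLOT_ID = -1
buf = torch.zeros(8, dtype=torch.int64)                      # stands in for _slot_mapping_buffer
slot_mapping = torch.tensor([5, 6, 7], dtype=torch.int64)    # 3 real tokens
num_tokens = 6                                               # cudagraph-padded token count

buf[:3].copy_(slot_mapping)
buf[3:num_tokens].fill_(PADDING_SLOT_ID)
view = buf[:num_tokens]                                      # tensor([ 5,  6,  7, -1, -1, -1])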
|
||||
|
||||
def _determine_batch_execution_and_padding(
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens,
|
||||
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
|
||||
)
|
||||
num_tokens_padded = batch_desc.num_tokens
|
||||
|
||||
# Extra coordination when running data-parallel since we need to
|
||||
# coordinate across ranks
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
should_ubatch, num_tokens_across_dp = False, None
|
||||
if self.vllm_config.parallel_config.data_parallel_size > 1:
|
||||
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
cudagraph_mode=cudagraph_mode.value,
|
||||
)
|
||||
)
|
||||
assert not should_ubatch, (
|
||||
"DBO ubatching not implemented for extract_hidden_states"
|
||||
)
|
||||
|
||||
# Extract DP-synced values
|
||||
if num_tokens_across_dp is not None:
|
||||
dp_rank = self.dp_rank
|
||||
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
|
||||
# Re-dispatch with DP padding so we have the correct
|
||||
# batch_descriptor
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens_padded,
|
||||
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
|
||||
)
|
||||
# Assert to make sure the agreed upon token count is correct
|
||||
# otherwise num_tokens_across_dp will no-longer be valid
|
||||
assert batch_desc.num_tokens == num_tokens_padded
|
||||
num_tokens_across_dp[dp_rank] = num_tokens_padded
|
||||
|
||||
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
"""Initialize cudagraph dispatcher keys.
|
||||
|
||||
Only supports PIECEWISE cudagraphs (via mixed_mode).
|
||||
Should be called after adjust_cudagraph_sizes_for_spec_decode.
|
||||
"""
|
||||
assert self.vllm_config.speculative_config is not None
|
||||
if (
|
||||
not self.vllm_config.speculative_config.enforce_eager
|
||||
and cudagraph_mode.mixed_mode()
|
||||
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
|
||||
):
|
||||
proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
proposer_cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
is_graph_capturing: bool = False,
|
||||
slot_mappings: dict[str, torch.Tensor] | None = None,
|
||||
) -> None:
|
||||
assert self.model is not None, "Model must be initialized before dummy_run"
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens, use_cudagraphs=use_cudagraphs
|
||||
)
|
||||
)
|
||||
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
# Use our own slot mapping buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
slot_mapping_dict = slot_mappings or {}
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=slot_mapping_dict,
|
||||
):
|
||||
self.model(
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
|
||||
def _build_attn_metadata_builder(
|
||||
self, draft_attn_layers: dict[str, AttentionLayerBase]
|
||||
) -> AttentionMetadataBuilder:
|
||||
"""Build the attention metadata builder from draft attention layers."""
|
||||
if not draft_attn_layers:
|
||||
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
|
||||
layer = next(iter(draft_attn_layers.values()))
|
||||
attn_backend = layer.get_attn_backend()
|
||||
return attn_backend.get_builder_cls()(
|
||||
layer.get_kv_cache_spec(self.vllm_config),
|
||||
self.attn_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
discard_request_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare next token IDs for speculative decoding.
|
||||
|
||||
Since num_speculative_tokens == 1, sampled_token_ids has shape
|
||||
(batch_size, 1). For each request we either use the sampled token
|
||||
(if valid and not discarded) or a backup token from the request state.
|
||||
"""
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
device = sampled_token_ids.device
|
||||
|
||||
# Compute backup tokens for discarded / invalid requests
|
||||
backup_tokens_gpu = torch.tensor(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||
common_attn_metadata.seq_lens_cpu[i].item()
|
||||
)
|
||||
for i in range(num_reqs)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert discard_request_mask.dtype == torch.bool
|
||||
|
||||
# With num_speculative_tokens == 1, there is exactly one token
|
||||
sampled = sampled_token_ids[:, 0]
|
||||
is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
|
||||
valid_sampled_tokens_count = is_valid.to(torch.int32)
|
||||
|
||||
use_sampled = is_valid & ~discard_request_mask[:num_reqs]
|
||||
next_token_ids = torch.where(
|
||||
use_sampled, sampled.to(torch.int32), backup_tokens_gpu
|
||||
)
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
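For illustration only (not part of the new file), the same selection with toy values for the single-token case; the numbers are made up:

import torch

sampled_token_ids = torch.tensor([[12], [-1], [99999]])   # [num_reqs, 1]
discard_request_mask = torch.tensor([False, False, False])
backup_tokens_gpu = torch.tensor([3, 4, 5], dtype=torch.int32)
vocab_size = 32_000

sampled = sampled_token_ids[:, 0]
is_valid = (sampled >= 0) & (sampled < vocab_size)
use_sampled = is_valid & ~discard_request_mask
next_token_ids = torch.where(use_sampled, sampled.to(torch.int32), backup_tokens_gpu)
# next_token_ids -> tensor([12,  4,  5], dtype=torch.int32)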
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
"""Load the ExtractHiddenStatesModel model.
|
||||
|
||||
This method instantiates the ExtractHiddenStatesModel model which is used
|
||||
to cache hidden states during speculative decoding. The model uses
|
||||
cache-only attention (no computation, just caching KV states).
|
||||
|
||||
Args:
|
||||
target_model: The target model (passed for compatibility with
|
||||
EagleProposer interface, but not used here)
|
||||
"""
|
||||
# Get the target model's attention layers before loading draft model
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() # type: ignore[type-abstract]
|
||||
)
|
||||
|
||||
assert self.vllm_config.speculative_config is not None
|
||||
draft_model_config = self.vllm_config.speculative_config.draft_model_config
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
with set_model_tag("extract_hidden_states"):
|
||||
self.model = get_model(
|
||||
vllm_config=self.vllm_config, model_config=draft_model_config
|
||||
)
|
||||
|
||||
# Identify draft model's attention layers (difference from target)
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
draft_attn_layers = {
|
||||
name: layer
|
||||
for name, layer in all_attn_layers.items()
|
||||
if name not in target_attn_layer_names
|
||||
}
|
||||
self.attn_layer_names = list(draft_attn_layers.keys())
|
||||
assert len(draft_attn_layers) == 1, (
|
||||
"ExtractHiddenStatesModel should have exactly one "
|
||||
f"attention layer, found {len(draft_attn_layers)}"
|
||||
)
|
||||
self.attn_metadata_builder = self._build_attn_metadata_builder(
|
||||
draft_attn_layers
|
||||
)
|
||||
|
||||
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Validate all drafting layers belong to the same KV cache group.
|
||||
|
||||
With exactly one attention layer (asserted in load_model), this is
|
||||
trivially satisfied.
|
||||
"""
|
||||
assert len(self.attn_layer_names) == 1
|
||||
@@ -64,3 +64,45 @@ class SpecDecodeMetadata:
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiLayerEagleMetadata:
|
||||
# [batch_size]
|
||||
cached_len: torch.Tensor | None = None
|
||||
# [batch_size, layer_num]
|
||||
cached_token_ids: torch.Tensor | None = None
|
||||
# [batch_size, layer_num, hidden_size]
|
||||
cached_hidden_states: torch.Tensor | None = None
|
||||
# [batch_size, layer_num]
|
||||
cached_slot_mappings: torch.Tensor | None = None
|
||||
# [batch_size, layer_num]
|
||||
cached_positions: torch.Tensor | None = None
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
layer_num: int,
|
||||
hidden_size: int,
|
||||
device: torch.device,
|
||||
) -> "MultiLayerEagleMetadata":
|
||||
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
|
||||
cached_token_ids = torch.zeros(
|
||||
(1, layer_num), dtype=torch.int32, device=device
|
||||
)
|
||||
cached_hidden_states = torch.zeros(
|
||||
(1, layer_num, hidden_size), dtype=torch.float32, device=device
|
||||
)
|
||||
cached_slot_mappings = torch.zeros(
|
||||
(1, layer_num), dtype=torch.int64, device=device
|
||||
)
|
||||
cached_positions = torch.zeros(
|
||||
(1, layer_num), dtype=torch.int64, device=device
|
||||
)
|
||||
return cls(
|
||||
cached_len=cached_len,
|
||||
cached_token_ids=cached_token_ids,
|
||||
cached_hidden_states=cached_hidden_states,
|
||||
cached_slot_mappings=cached_slot_mappings,
|
||||
cached_positions=cached_positions,
|
||||
)
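A possible warmup usage of make_dummy (illustrative only; the layer_num and hidden_size values here are arbitrary):

import torch

from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata

meta = MultiLayerEagleMetadata.make_dummy(
    layer_num=2, hidden_size=4096, device=torch.device("cpu")
)
assert meta.cached_hidden_states.shape == (1, 2, 4096)
assert meta.cached_token_ids.shape == (1, 2)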
|
||||
|
||||
vllm/v1/spec_decode/multi_layer_eagle.py (new file, 504 lines)
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
|
||||
|
||||
BLOCK_HIDDEN = 128
|
||||
BLOCK_TOKENS = 128
|
||||
|
||||
|
||||
class MultiLayerEagleProposer(EagleProposer):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
runner=None,
|
||||
):
|
||||
super().__init__(vllm_config, device, runner)
|
||||
|
||||
self.layer_num: int = getattr(
|
||||
self.speculative_config.draft_model_config.hf_text_config,
|
||||
"n_predict", 0
|
||||
)
|
||||
self.num_speculative_tokens: int = (
|
||||
self.speculative_config.num_speculative_tokens
|
||||
)
|
||||
|
||||
    def adjust_input(
        self,
        batch_size: int,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        assert multi_layer_eagle_metadata is not None
        if token_indices_to_sample is None:
            token_indices_to_sample = (
                common_attn_metadata.query_start_loc[1:] - 1
            )

        MAX_SHIFT = self.layer_num
        assert MAX_SHIFT > 0

        prev_token_ids = target_token_ids.clone()
        prev_positions = target_positions.clone()
        prev_hidden_states = target_hidden_states.clone()
        slot_mapping = common_attn_metadata.slot_mapping

        start_token_indices = common_attn_metadata.query_start_loc[:-1]
        end_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        pos_for_shift = (
            target_positions[0]
            if target_positions.dim() == 2
            else target_positions
        )
        start_token_pos = pos_for_shift[start_token_indices]

        shift = torch.minimum(
            end_token_indices - token_indices_to_sample,
            start_token_pos,
        )
        shift = torch.clamp(shift, min=0)

        token_indices_to_sample.add_(shift)
        common_attn_metadata.seq_lens.sub_(shift)

        cached_lens = multi_layer_eagle_metadata.cached_len
        shift = torch.minimum(shift, cached_lens)

        _multi_layer_eagle_shift_and_cache(
            batch_size=batch_size,
            max_shift=MAX_SHIFT,
            src_token_ids=target_token_ids,
            dst_token_ids=prev_token_ids,
            src_positions=target_positions,
            dst_positions=prev_positions,
            src_hidden_states=target_hidden_states,
            dst_hidden_states=prev_hidden_states,
            src_slot_mapping=slot_mapping,
            dst_slot_mapping=slot_mapping,
            start_token_indices=start_token_indices,
            end_token_indices=end_token_indices,
            token_indices_to_sample=token_indices_to_sample,
            shift=shift,
            cached_lens=cached_lens,
            cached_prev_token_ids=(
                multi_layer_eagle_metadata.cached_token_ids
            ),
            cached_prev_positions=(
                multi_layer_eagle_metadata.cached_positions
            ),
            cached_prev_hidden_states=(
                multi_layer_eagle_metadata.cached_hidden_states
            ),
            cached_slot_mappings=(
                multi_layer_eagle_metadata.cached_slot_mappings
            ),
            common_attn_metadata=common_attn_metadata,
        )

        return (
            prev_token_ids,
            prev_positions,
            prev_hidden_states,
            common_attn_metadata,
        )

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        raise Exception(
            "speculative_config.disable_padded_drafter_batch is not "
            "currently supported for MultiLayerEagleProposer."
        )

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        adjust_input_kwargs = {
            "batch_size": 1,
            "target_token_ids": self.input_ids[:num_input_tokens],
            "target_positions": self._get_positions(num_input_tokens),
            "target_hidden_states": self.hidden_states[:num_input_tokens],
            "token_indices_to_sample": torch.tensor(
                [num_input_tokens - 1],
                dtype=torch.int32,
                device=self.device,
            ),
            "common_attn_metadata": CommonAttentionMetadata(
                query_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                query_start_loc_cpu=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device="cpu",
                ),
                key_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens=torch.tensor(
                    [num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens_np=np.array([num_input_tokens], dtype=np.int32),
                num_reqs=1,
                num_actual_tokens=num_input_tokens,
                max_query_len=self.num_speculative_tokens + 1,
                max_seq_len=self.max_model_len,
                block_table_tensor=torch.tensor(
                    [], dtype=torch.int32, device=self.device
                ),
                slot_mapping=self.arange[:num_input_tokens],
                logits_indices_padded=None,
                num_logits_indices=None,
                causal=True,
                encoder_seq_lens=None,
            ),
            "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
                layer_num=self.layer_num,
                hidden_size=self.hidden_size,
                device=self.device,
            ),
        }
        self.adjust_input(**adjust_input_kwargs)

        for fwd_idx in range(self.layer_num):
            with set_forward_context(
                None,
                self.draft_vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "hidden_states": self.hidden_states[:num_input_tokens],
                    "inputs_embeds": inputs_embeds,
                    "spec_step_idx": fwd_idx,
                }

                self.model(**model_kwargs)


def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
):
    if batch_size == 0:
        return

    assert max_shift > 0
    assert cached_prev_positions.is_contiguous()
    assert cached_prev_token_ids.is_contiguous()
    assert cached_prev_hidden_states.is_contiguous()
    assert cached_slot_mappings.is_contiguous()
    assert src_hidden_states.is_contiguous()
    assert dst_hidden_states.is_contiguous()

    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()

    store_start = torch.maximum(
        start_token_indices,
        (token_indices_to_sample + 1 - max_shift),
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )

    max_window_len = int(
        (
            common_attn_metadata.query_start_loc_cpu[1:]
            - common_attn_metadata.query_start_loc_cpu[:-1]
        )
        .max()
        .item()
    )
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_token_ids,
        dst_token_ids,
        cached_prev_token_ids,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_slot_mapping,
        dst_slot_mapping,
        cached_slot_mappings,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_positions,
        dst_positions,
        cached_prev_positions,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    hidden_size = int(dst_hidden_states.shape[1])
    num_hidden_blocks = max(
        1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN
    )

    _shift_and_gather_hidden_kernel[
        (batch_size, num_blocks, num_hidden_blocks)
    ](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )

    cached_lens.copy_(store_lens)
    return


@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
    # slot mappings, ...).
    #
    # For a single sequence (0-based index i within its window):
    # - Prefix (i < shift):
    #     dst[start + i] = cached[cached_len - shift + i]
    # - Body (i >= shift):
    #     dst[start + i] = src[start + i - shift]
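    #
    # Illustrative example (reviewer note, not part of this diff): with
    # start=4, end=9, shift=2, cached_len=3, the window spans 6 tokens and
    # the kernel writes
    #     dst[4] = cached[1], dst[5] = cached[2]                  (prefix from cache)
    #     dst[6] = src[4], dst[7] = src[5], ..., dst[9] = src[7]  (shifted body)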
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)

    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)

    assert cached_len >= shift

    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    offs = base + k
    dst_idx = start + offs

    window_len = end - start + 1
    mask = offs < window_len

    base_cached = cached_ptr + pid_seq * MAX_SHIFT
    cached_idx = cached_len - shift + offs
    cached_mask = offs < shift
    val_cached = tl.load(
        base_cached + cached_idx, mask=mask & cached_mask, other=0
    )

    src_idx = start + offs - shift
    val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)

    val = tl.where(cached_mask, val_cached, val_src)
    tl.store(dst_ptr + dst_idx, val, mask=mask)

    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    store_mask = m < MAX_SHIFT
    dst_idx = store_start + m
    val = tl.load(
        dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0
    )
    tl.store(base_cached + m, val, mask=store_mask)


@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Per-sequence "shift + gather" for hidden states.
    # Layout:
    # - src_ptr / dst_ptr: [num_tokens, hidden_size]
    # - cached_ptr: [batch_size, MAX_SHIFT, hidden_size]
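    #
    # Flattened addressing used below (reviewer note, not part of this diff):
    # element (seq, tok, h) of cached_ptr lives at element offset
    #     seq * MAX_SHIFT * HIDDEN_SIZE + tok * HIDDEN_SIZE + h,
    # which is exactly base_cached + cached_tok * HIDDEN_SIZE + n as computed
    # a few lines down.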
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    pid_hid = tl.program_id(2)

    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)

    assert cached_len >= shift

    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    tok_offs = base + k
    dst_tok = start + tok_offs
    n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
    dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1

    window_len = end - start + 1
    tok_mask = tok_offs < window_len
    n_mask = n < HIDDEN_SIZE
    mask = tok_mask[:, None] & n_mask[None, :]

    base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
    cached_tok = cached_len - shift + tok_offs
    cached_ptrs = (
        base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    cached_mask = tok_offs < shift
    val_cached = tl.load(
        cached_ptrs, mask=mask & cached_mask[:, None], other=0
    )

    src_tok = start + tok_offs - shift
    src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)

    val = tl.where(cached_mask[:, None], val_cached, val_src)
    tl.store(dst_ptrs, val, mask=mask)

    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    m_mask = (m < MAX_SHIFT) & (m < store_len)
    store_tok = store_start + m
    dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    store_ptrs = (
        base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    mask = m_mask[:, None] & n_mask[None, :]
    val = tl.load(dst_ptrs, mask=mask, other=0)
    tl.store(store_ptrs, val, mask=mask)
@@ -157,12 +157,23 @@ def create_vllm_config_for_draft_model(
    quantized differently, and has potentially different tensor_parallel_size.
    This function creates a new vllm_config configured for the drafter.
    The returned vllm_config is useful when loading the draft model with get_model().

    This helper returns the original target config in the common case and only
    rewrites the rank/parallel info when the drafter is configured to run
    locally on the last target PP stage. This keeps runtime behavior unchanged
    for the common case while still handling PP rank remapping.
    """
    old = target_model_vllm_config
    assert old.speculative_config is not None, "speculative_config is not set"
    old_spec_config = old.speculative_config
    needs_rank_remap = old_spec_config.needs_partial_pp_draft_remap(old.parallel_config)
    if not needs_rank_remap:
        return old

    draft_rank = old_spec_config.resolve_partial_pp_draft_rank(old.parallel_config)

    new_parallel_config = replace(
        old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
        old_spec_config.draft_parallel_config, rank=draft_rank
    )
    new: VllmConfig = replace(
        old,