Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import ast
|
||||
from dataclasses import replace
|
||||
from importlib.util import find_spec
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -20,17 +20,13 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.tree_attn import (
|
||||
TreeAttentionMetadata,
|
||||
@@ -38,14 +34,15 @@ from vllm.v1.attention.backends.tree_attn import (
|
||||
)
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.utils import (
|
||||
PADDING_SLOT_ID,
|
||||
compute_new_slot_mapping,
|
||||
copy_and_expand_eagle_inputs_kernel,
|
||||
create_vllm_config_for_draft_model,
|
||||
eagle_prepare_inputs_padded_kernel,
|
||||
eagle_prepare_next_token_padded_kernel,
|
||||
extend_all_queries_by_N,
|
||||
@@ -53,6 +50,7 @@ from vllm.v1.spec_decode.utils import (
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -68,6 +66,7 @@ class SpecDecodeBaseProposer:
|
||||
self.vllm_config = vllm_config
|
||||
assert vllm_config.speculative_config is not None
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.draft_vllm_config = create_vllm_config_for_draft_model(vllm_config)
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
self.pass_hidden_states_to_model = pass_hidden_states_to_model
|
||||
@@ -79,6 +78,9 @@ class SpecDecodeBaseProposer:
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||
|
||||
self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
|
||||
self.layer_num = 1
|
||||
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
@@ -113,21 +115,19 @@ class SpecDecodeBaseProposer:
|
||||
vllm_config.model_config
|
||||
)
|
||||
|
||||
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
self.draft_attn_groups: list[AttentionGroup] = []
|
||||
self.kv_cache_gid: int = -1
|
||||
self.eagle3_use_aux_hidden_state: bool = (
|
||||
self._get_eagle3_use_aux_hidden_state_from_config()
|
||||
)
|
||||
|
||||
self.compilation_config = self.vllm_config.compilation_config
|
||||
self.compilation_config = self.draft_vllm_config.compilation_config
|
||||
|
||||
# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
|
||||
# Keys are initialized later via initialize_cudagraph_keys() called from
|
||||
# gpu_model_runner._check_and_update_cudagraph_mode after
|
||||
# adjust_cudagraph_sizes_for_spec_decode is called.
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.draft_vllm_config)
|
||||
|
||||
# persistent buffers for cuda graph
|
||||
self.input_ids = torch.zeros(
|
||||
@@ -353,7 +353,7 @@ class SpecDecodeBaseProposer:
|
||||
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
|
||||
|
||||
view = self._slot_mapping_buffer[:num_tokens]
|
||||
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
|
||||
return {name: view for name in self._draft_attn_layer_names}
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
"""Initialize cudagraph dispatcher keys for eagle.
|
||||
@@ -372,6 +372,23 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
|
||||
|
||||
def adjust_input(
|
||||
self,
|
||||
batch_size: int,
|
||||
target_token_ids: torch.Tensor,
|
||||
target_positions: torch.Tensor,
|
||||
target_hidden_states: torch.Tensor,
|
||||
token_indices_to_sample: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
|
||||
return (
|
||||
target_token_ids,
|
||||
target_positions,
|
||||
target_hidden_states,
|
||||
common_attn_metadata,
|
||||
)
|
||||
|
||||
def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Greedy-sample draft tokens from hidden states."""
|
||||
if self.use_local_argmax_reduction:
|
||||
@@ -391,6 +408,7 @@ class SpecDecodeBaseProposer:
|
||||
token_indices_to_sample: torch.Tensor | None,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
@@ -406,6 +424,21 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
(
|
||||
target_token_ids,
|
||||
target_positions,
|
||||
target_hidden_states,
|
||||
common_attn_metadata,
|
||||
) = self.adjust_input(
|
||||
batch_size=batch_size,
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
token_indices_to_sample=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
|
||||
)
|
||||
|
||||
num_tokens, token_indices_to_sample, common_attn_metadata = (
|
||||
self.set_inputs_first_pass(
|
||||
target_token_ids=target_token_ids,
|
||||
@@ -420,109 +453,114 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
if self.attn_metadata_builder is None:
|
||||
attn_metadata_builder = self._get_attention_metadata_builder()
|
||||
else:
|
||||
attn_metadata_builder = self.attn_metadata_builder
|
||||
per_layer_attn_metadata: dict[str, object] = {}
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(num_tokens)
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model (remove separate indexer)
|
||||
if self.draft_indexer_metadata_builder:
|
||||
draft_indexer_metadata = (
|
||||
self.draft_indexer_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
|
||||
draft_token_ids_list = []
|
||||
for spec_step_idx in range(self.layer_num):
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
|
||||
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
|
||||
self.input_ids[:num_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
)
|
||||
else:
|
||||
draft_indexer_metadata = None
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
for layer_name in self.indexer_layer_names:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens_dp_padded
|
||||
)
|
||||
num_input_tokens = batch_desc.num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
|
||||
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
|
||||
self.input_ids[:num_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
model_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": self._get_positions(num_input_tokens),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1 or self.parallel_drafting:
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
return draft_token_ids.view(-1, self.num_speculative_tokens)
|
||||
if self.enable_multi_layers_mtp:
|
||||
model_kwargs["spec_step_idx"] = spec_step_idx
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
if self.enable_multi_layers_mtp:
|
||||
logits = self.model.compute_logits(
|
||||
sample_hidden_states, spec_step_idx=spec_step_idx
|
||||
)
|
||||
else:
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
if spec_step_idx < self.layer_num - 1:
|
||||
prev_token_ids = self.input_ids[:num_tokens].clone()
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
next_token_ids = draft_token_ids_list[-1].int()
|
||||
|
||||
num_tokens, token_indices_to_sample, common_attn_metadata = (
|
||||
self.set_inputs_first_pass(
|
||||
target_token_ids=prev_token_ids,
|
||||
next_token_ids=next_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=hidden_states,
|
||||
token_indices_to_sample=token_indices_to_sample,
|
||||
cad=common_attn_metadata,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
)
|
||||
)
|
||||
|
||||
# Early exit if all draft tokens are generated in one pass
|
||||
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, token_indices_to_sample]
|
||||
else:
|
||||
positions = self.positions[token_indices_to_sample]
|
||||
if self.method in (
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
):
|
||||
if self.method == "mtp":
|
||||
hidden_states = self.hidden_states[token_indices_to_sample]
|
||||
else:
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
# Draft using tree attention - requires full logits for top-k
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if self.enable_multi_layers_mtp:
|
||||
raise NotImplementedError(
|
||||
"Speculative Decoding with multi-layer MTP and tree attention "
|
||||
"is not supported yet."
|
||||
)
|
||||
# Draft using tree attention.
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
@@ -534,32 +572,20 @@ class SpecDecodeBaseProposer:
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
|
||||
if self.allowed_attn_types is not None and not isinstance(
|
||||
attn_metadata, self.allowed_attn_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
"decoding with num_speculative_tokens > layer_num: "
|
||||
f"{type(attn_metadata)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}"
|
||||
)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
|
||||
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
|
||||
self._determine_batch_execution_and_padding(batch_size)
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
batch_size_dp_padded
|
||||
)
|
||||
input_batch_size = batch_desc.num_tokens
|
||||
if batch_size_across_dp is not None:
|
||||
batch_size_across_dp[self.dp_rank] = input_batch_size
|
||||
|
||||
common_attn_metadata.num_actual_tokens = batch_size
|
||||
common_attn_metadata.max_query_len = 1
|
||||
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
|
||||
@@ -577,7 +603,7 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._seq_lens_cpu = None
|
||||
common_attn_metadata._num_computed_tokens_cpu = None
|
||||
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
for token_index in range(self.num_speculative_tokens - self.layer_num):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@@ -627,7 +653,8 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_size = attn_metadata_builder.kv_cache_spec.block_size
|
||||
# Use the first draft attention group's kv_cache_spec for block_size
|
||||
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||
if self.uses_mrope:
|
||||
# all dimensions of positions are the same
|
||||
block_numbers = clamped_positions[0] // block_size
|
||||
@@ -653,11 +680,13 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
|
||||
)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
@@ -683,7 +712,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=input_batch_size,
|
||||
num_tokens_across_dp=batch_size_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
@@ -819,18 +848,17 @@ class SpecDecodeBaseProposer:
|
||||
# 2.
|
||||
# Recompute the slot mapping based on the new positions and
|
||||
# rejection mask.
|
||||
builder = (
|
||||
self._get_attention_metadata_builder()
|
||||
if self.attn_metadata_builder is None
|
||||
else self.attn_metadata_builder
|
||||
)
|
||||
# Use the first draft attention group's kv_cache_spec for block_size
|
||||
# (all draft layers share the same kv-cache group)
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||
new_slot_mapping = compute_new_slot_mapping(
|
||||
cad=cad,
|
||||
new_positions=self.positions[:total_num_output_tokens],
|
||||
is_rejected_token_mask=self.is_rejected_token_mask[
|
||||
:total_num_output_tokens
|
||||
],
|
||||
block_size=builder.kv_cache_spec.block_size,
|
||||
block_size=block_size,
|
||||
num_new_tokens=self.net_num_new_slots_per_request,
|
||||
max_model_len=self.max_model_len,
|
||||
)
|
||||
@@ -880,6 +908,69 @@ class SpecDecodeBaseProposer:
|
||||
next_token_ids, dtype=torch.int32, device=self.input_ids.device
|
||||
)
|
||||
return next_token_ids
|
||||
|
||||
def eagle_prepare_next_token_padded(
|
||||
self,
|
||||
sampled_token_ids, # [num_reqs, num_sampled_tokens_per_req]
|
||||
discard_request_mask, # [num_reqs]
|
||||
backup_next_token_ids, # [num_reqs]
|
||||
vocab_size,
|
||||
):
|
||||
"""
|
||||
PyTorch implementation of eagle_prepare_next_token_padded kernel.
|
||||
|
||||
Args:
|
||||
sampled_token_ids: Tensor of shape [num_reqs, num_sampled_tokens_per_req]
|
||||
containing sampled token ids (-1 for rejected tokens)
|
||||
discard_request_mask: Boolean tensor of shape [num_reqs] indicating
|
||||
which requests should be discarded
|
||||
backup_next_token_ids: Tensor of shape [num_reqs] containing backup
|
||||
token ids for when no valid tokens are found
|
||||
vocab_size: Vocabulary size for validity checking
|
||||
|
||||
Returns:
|
||||
next_token_ids: Tensor of shape [num_reqs] containing the next token
|
||||
to sample (last accepted token or backup)
|
||||
valid_sampled_tokens_count: Tensor of shape [num_reqs] containing the
|
||||
number of valid (1 + accepted) tokens
|
||||
"""
|
||||
num_reqs = sampled_token_ids.shape[0]
|
||||
num_sampled_tokens_per_req = sampled_token_ids.shape[1]
|
||||
|
||||
# Initialize output tensors
|
||||
next_token_ids = torch.empty(num_reqs, dtype=sampled_token_ids.dtype, device=sampled_token_ids.device)
|
||||
valid_sampled_tokens_count = torch.zeros(num_reqs, dtype=torch.int32, device=sampled_token_ids.device)
|
||||
|
||||
# Process each request
|
||||
for req_idx in range(num_reqs):
|
||||
if discard_request_mask[req_idx]:
|
||||
# Discarded request: use backup token and valid_count=0
|
||||
next_token_ids[req_idx] = backup_next_token_ids[req_idx]
|
||||
valid_sampled_tokens_count[req_idx] = 0
|
||||
else:
|
||||
# Get sampled tokens for this request
|
||||
tokens = sampled_token_ids[req_idx]
|
||||
|
||||
# Find valid tokens (not -1 and within vocabulary range)
|
||||
is_valid = (tokens != -1) & (tokens < vocab_size)
|
||||
valid_count = is_valid.sum().item()
|
||||
|
||||
if valid_count > 0:
|
||||
# Find the last valid token index
|
||||
# Get indices where is_valid is True
|
||||
valid_indices = torch.where(is_valid)[0]
|
||||
last_valid_idx = valid_indices[-1].item()
|
||||
|
||||
# Get the token at that index
|
||||
last_valid_token = tokens[last_valid_idx]
|
||||
next_token_ids[req_idx] = last_valid_token
|
||||
else:
|
||||
# No valid tokens, use backup token
|
||||
next_token_ids[req_idx] = backup_next_token_ids[req_idx]
|
||||
|
||||
valid_sampled_tokens_count[req_idx] = valid_count
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
@@ -910,31 +1001,15 @@ class SpecDecodeBaseProposer:
|
||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||
backup_tokens_gpu = self.backup_next_token_ids.gpu
|
||||
|
||||
batch_size, num_tokens = sampled_token_ids.shape
|
||||
device = sampled_token_ids.device
|
||||
|
||||
assert discard_request_mask.dtype == torch.bool
|
||||
assert backup_tokens_gpu.dtype == torch.int32
|
||||
|
||||
next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
|
||||
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
|
||||
|
||||
# Kernel grid: one program per request (row)
|
||||
grid = (batch_size,)
|
||||
|
||||
# Find the next power of 2 for block sizes
|
||||
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
|
||||
eagle_prepare_next_token_padded_kernel[grid](
|
||||
|
||||
next_token_ids, valid_sampled_tokens_count = self.eagle_prepare_next_token_padded(
|
||||
sampled_token_ids,
|
||||
discard_request_mask,
|
||||
backup_tokens_gpu,
|
||||
next_token_ids,
|
||||
valid_sampled_tokens_count,
|
||||
gpu_input_batch.vocab_size,
|
||||
num_tokens,
|
||||
batch_size,
|
||||
sampled_token_ids.stride(0),
|
||||
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
|
||||
)
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
@@ -974,6 +1049,8 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
key_start_loc = common_attn_metadata.key_start_loc
|
||||
|
||||
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
@@ -981,7 +1058,9 @@ class SpecDecodeBaseProposer:
|
||||
spec_common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
seq_lens_np = common_attn_metadata.seq_lens_np,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
key_start_loc=key_start_loc,
|
||||
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
@@ -1014,9 +1093,7 @@ class SpecDecodeBaseProposer:
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
tree_attn_metadata_builder = self.runner.attn_groups[0][
|
||||
0
|
||||
].get_metadata_builder()
|
||||
tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
|
||||
|
||||
total_num_drafts = self.cu_drafts_per_level[0]
|
||||
@@ -1092,10 +1169,11 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata=common_attn_metadata, draft_index=level + 1
|
||||
)
|
||||
|
||||
# Apply new attention metadata to all layers.
|
||||
# Apply new attention metadata to all draft layers.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# Consider max model length.
|
||||
attn_metadata.max_seq_len = min(
|
||||
@@ -1131,7 +1209,7 @@ class SpecDecodeBaseProposer:
|
||||
# Run the model.
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
@@ -1209,6 +1287,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
key_start_loc = common_attn_metadata.key_start_loc
|
||||
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
|
||||
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
||||
@@ -1261,6 +1340,7 @@ class SpecDecodeBaseProposer:
|
||||
query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
|
||||
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
|
||||
query_start_loc_cpu=new_query_start_loc_cpu,
|
||||
key_start_loc=key_start_loc,
|
||||
_seq_lens_cpu=new_seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
@@ -1289,7 +1369,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_model_tag("eagle_head"):
|
||||
model = get_model(
|
||||
vllm_config=self.vllm_config,
|
||||
vllm_config=self.draft_vllm_config,
|
||||
model_config=self.speculative_config.draft_model_config,
|
||||
load_config=self.speculative_config.draft_load_config,
|
||||
)
|
||||
@@ -1302,43 +1382,17 @@ class SpecDecodeBaseProposer:
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model
|
||||
target_indexer_layer_names = set(
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
).keys()
|
||||
)
|
||||
|
||||
self.model = self._get_model()
|
||||
|
||||
draft_attn_layer_names = (
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
- target_attn_layer_names
|
||||
# Find draft layers (attention layers added by draft model)
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.draft_vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
indexer_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
self._draft_attn_layer_names = (
|
||||
set(all_attn_layers.keys()) - target_attn_layer_names
|
||||
)
|
||||
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
|
||||
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
|
||||
self.indexer_layer_names = list(draft_indexer_layer_names)
|
||||
|
||||
if self.indexer_layer_names:
|
||||
first_layer = self.indexer_layer_names[0]
|
||||
self.draft_indexer_metadata_builder = (
|
||||
indexer_layers[first_layer]
|
||||
.get_attn_backend()
|
||||
.get_builder_cls()(
|
||||
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
|
||||
self.indexer_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.draft_indexer_metadata_builder = None
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
# Even if the target model is multimodal, we can also use
|
||||
@@ -1568,25 +1622,17 @@ class SpecDecodeBaseProposer:
|
||||
self.num_speculative_tokens if not is_graph_capturing else 1
|
||||
):
|
||||
if fwd_idx <= 1:
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
|
||||
)
|
||||
if use_cudagraphs:
|
||||
cudagraph_runtime_mode, batch_desc = (
|
||||
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens, use_cudagraphs=use_cudagraphs
|
||||
)
|
||||
num_input_tokens = batch_desc.num_tokens
|
||||
else:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
num_input_tokens = num_tokens_dp_padded
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
)
|
||||
|
||||
# Make sure to use EAGLE's own buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
self._draft_attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
and next(iter(self._draft_attn_layer_names)) in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
@@ -1594,7 +1640,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
@@ -1616,31 +1662,6 @@ class SpecDecodeBaseProposer:
|
||||
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
self.model(**kwargs)
|
||||
|
||||
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
|
||||
Returns:
|
||||
The metadata builders for EAGLE layers.
|
||||
|
||||
Raises:
|
||||
AssertionError: If no metadata builders are found for EAGLE layers.
|
||||
"""
|
||||
builder = None
|
||||
chosen_layer = self.attn_layer_names[0]
|
||||
|
||||
for kv_cache_group in self.runner.attn_groups:
|
||||
for attn_group in kv_cache_group:
|
||||
if chosen_layer in attn_group.layer_names:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
break
|
||||
if builder is not None:
|
||||
break
|
||||
|
||||
assert builder is not None, (
|
||||
"Failed to find attention metadata builder for EAGLE layers."
|
||||
)
|
||||
return builder
|
||||
|
||||
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
|
||||
"""
|
||||
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
|
||||
@@ -1673,35 +1694,114 @@ class SpecDecodeBaseProposer:
|
||||
set(
|
||||
[
|
||||
kv_cache_groups[layer_name]
|
||||
for layer_name in self.attn_layer_names
|
||||
for layer_name in self._draft_attn_layer_names
|
||||
]
|
||||
)
|
||||
)
|
||||
== 1
|
||||
), "All drafting layers should belong to the same kv cache group"
|
||||
|
||||
def _pad_batch_across_dp(
|
||||
def initialize_attn_backend(
|
||||
self,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
) -> tuple[int, torch.Tensor]:
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
|
||||
!= CUDAGraphMode.NONE,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
uniform_decode=None,
|
||||
num_scheduled_tokens_per_request=None,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kernel_block_sizes: list[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AttentionGroups for draft layers using kv_cache_config.
|
||||
Called from the model runner's initialize_metadata_builders.
|
||||
"""
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.draft_vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
|
||||
|
||||
num_tokens_dp_padded = num_tokens_padded
|
||||
if num_toks_across_dp is not None:
|
||||
num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
|
||||
return num_tokens_dp_padded, num_toks_across_dp
|
||||
# Find which kv_cache_group the draft layers belong to
|
||||
self.validate_same_kv_cache_group(kv_cache_config)
|
||||
kv_cache_spec = None
|
||||
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
if self._draft_attn_layer_names & set(group.layer_names):
|
||||
self.kv_cache_gid = gid
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
break
|
||||
|
||||
attention_groups: dict[tuple[str, str], AttentionGroup] = {}
|
||||
if kv_cache_spec is not None:
|
||||
for layer_name in self._draft_attn_layer_names:
|
||||
attn_backend = all_attn_layers[layer_name].get_attn_backend()
|
||||
backend_key = attn_backend.full_cls_name()
|
||||
if backend_key not in attention_groups:
|
||||
layer_kv_cache_spec = kv_cache_spec
|
||||
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
|
||||
layer_name
|
||||
]
|
||||
|
||||
kernel_block_size = (
|
||||
kernel_block_sizes[self.kv_cache_gid]
|
||||
if kernel_block_sizes is not None
|
||||
and self.kv_cache_gid < len(kernel_block_sizes)
|
||||
else None
|
||||
)
|
||||
attn_group = AttentionGroup(
|
||||
backend=attn_backend,
|
||||
layer_names=[layer_name],
|
||||
kv_cache_spec=layer_kv_cache_spec,
|
||||
kv_cache_group_id=self.kv_cache_gid,
|
||||
)
|
||||
attn_group.create_metadata_builders(
|
||||
self.draft_vllm_config,
|
||||
self.device,
|
||||
kernel_block_size=kernel_block_size,
|
||||
)
|
||||
attention_groups[backend_key] = attn_group
|
||||
else:
|
||||
attention_groups[backend_key].layer_names.append(layer_name)
|
||||
|
||||
self.draft_attn_groups = list(attention_groups.values())
|
||||
|
||||
def _determine_batch_execution_and_padding(
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens,
|
||||
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
|
||||
)
|
||||
num_tokens_padded = batch_desc.num_tokens
|
||||
|
||||
# Extra coordination when running data-parallel since we need to
|
||||
# coordinate across ranks
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
should_ubatch, num_tokens_across_dp = False, None
|
||||
if self.draft_vllm_config.parallel_config.data_parallel_size > 1:
|
||||
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
parallel_config=self.draft_vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
cudagraph_mode=cudagraph_mode.value,
|
||||
)
|
||||
)
|
||||
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
|
||||
|
||||
# Extract DP-synced values
|
||||
if num_tokens_across_dp is not None:
|
||||
dp_rank = self.dp_rank
|
||||
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
|
||||
# Re-dispatch with DP padding so we have the correct
|
||||
# batch_descriptor
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens_padded,
|
||||
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
|
||||
)
|
||||
# Assert to make sure the agreed upon token count is correct
|
||||
# otherwise num_tokens_across_dp will no-longer be valid
|
||||
assert batch_desc.num_tokens == num_tokens_padded
|
||||
num_tokens_across_dp[dp_rank] = num_tokens_padded
|
||||
|
||||
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
|
||||
|
||||
|
||||
class EagleProposer(SpecDecodeBaseProposer):
|
||||
|
||||
Reference in New Issue
Block a user