[Feat] Merge the multiple eagle graphs into one graph (#5940)

### What this PR does / why we need it?
This PR merges all steps of the draft model into a single full-graph capture, avoiding the synchronization between per-step graphs and reducing bubble time.

#### Key ideas:
- The model forward of step 0 (the first step) and of all remaining steps is captured together as one `Callable`, rather than capturing each step's model call individually (a minimal sketch follows this list).
- `update_attn_params` is moved outside the merged graph: the `attn_metadata` needed by every step is constructed before `replay`, and the attention params of all steps are updated in one pass.
- The synchronization between the main-model graph and the draft-model graph is removed.
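
Below is a minimal, framework-agnostic sketch of the first idea. It is not the vllm-ascend implementation: `DraftModel` and `run_merged_draft` are illustrative stand-ins, and the real code wraps the merged callable in `ACLGraphWrapper` and runs it on the NPU. The point is only that every speculative step lives inside one callable, so a single capture/replay covers the whole draft loop and no per-step synchronization is needed.

```python
import torch
import torch.nn as nn


class DraftModel(nn.Module):
    """Toy stand-in for the draft (eagle/mtp) model."""

    def __init__(self, hidden: int = 32, vocab: int = 100):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)


def run_merged_draft(model: DraftModel, hidden_states: torch.Tensor,
                     num_speculative_tokens: int) -> torch.Tensor:
    # All draft steps run back-to-back inside one callable, so a graph wrapper
    # can capture this whole function once, instead of capturing (and then
    # synchronizing around) one graph per step.
    draft_token_ids = []
    for _ in range(num_speculative_tokens):
        hidden_states = model(hidden_states)
        logits = model.compute_logits(hidden_states)
        draft_token_ids.append(logits.argmax(dim=-1))
    return torch.stack(draft_token_ids, dim=1)  # [batch, num_speculative_tokens]


if __name__ == "__main__":
    tokens = run_merged_draft(DraftModel(), torch.randn(4, 32),
                              num_speculative_tokens=3)
    print(tokens.shape)  # torch.Size([4, 3])
```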

#### Key params/functions:
- params: `draft_attn_metadatas`, `attn_metadata_multi_steps`, `slot_mapping_group` (a sketch of the per-step slot-mapping buffers follows this list)
- functions: `_run_merged_draft`, `attn_update_stack_num_spec_norm`, `update_attn_params`, `_propose`, `dummy_run`
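
The sketch below illustrates the `slot_mapping_group` idea under simplifying assumptions (CPU tensors, made-up sizes, and a plain dict standing in for the real attention-metadata objects): every speculative step owns a persistent slot-mapping buffer, so the metadata for all steps can be filled in before the single graph replay, which is what lets the attention params of every step be updated at once.

```python
import torch

# Assumed toy sizes; the real buffers are sized from the runner's limits
# and allocated on the NPU.
num_speculative_tokens = 3
max_tokens = 16

# One persistent buffer per speculative step, padded with -1 (invalid slot).
slot_mapping_group = [
    torch.full((max_tokens,), -1, dtype=torch.int32)
    for _ in range(num_speculative_tokens)
]


def prepare_step(step: int, slot_mapping: torch.Tensor) -> dict:
    # Copy this step's slot mapping into its dedicated buffer; a captured
    # graph only ever reads from the stable buffer address.
    n = slot_mapping.shape[0]
    slot_mapping_group[step][:n].copy_(slot_mapping)
    slot_mapping_group[step][n:].fill_(-1)
    return {"slot_mapping": slot_mapping_group[step][:n]}


# Build the metadata for every step up front, before any replay.
multi_steps_attn_metadata = [
    prepare_step(step, torch.arange(4, dtype=torch.int32) + step)
    for step in range(num_speculative_tokens)
]
print([m["slot_mapping"].tolist() for m in multi_steps_attn_metadata])
```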

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Author: anon189Ty
Date: 2026-01-23 08:37:02 +08:00 (committed by GitHub)
Parent: 63d3921208
Commit: 7725314b26
5 changed files with 396 additions and 218 deletions


@@ -295,6 +295,7 @@ class TestACLGraphWrapper(TestBase):
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
+        self.mock_forward_context.is_draft_model = False

         # Mock torch.npu.NPUGraph
         mock_npu_graph = MagicMock()
@@ -366,6 +367,7 @@ class TestACLGraphWrapper(TestBase):
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
+        self.mock_forward_context.is_draft_model = False

         # Mock torch.npu.NPUGraph
         mock_npu_graph = MagicMock()


@@ -20,6 +20,7 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config = MagicMock()
         self.device = torch.device("cpu")
         self.runner = MagicMock()
+        self.runner.pin_memory = False
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
@@ -93,6 +94,23 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.scheduler_config.async_scheduling = True
         init_ascend_config(self.vllm_config)
+        proposer = EagleProposer(vllm_config=self.vllm_config,
+                                 device=self.device,
+                                 runner=self.runner)
+        self.assertEqual(proposer.hidden_size, 2048)
+        self.assertTrue(proposer.use_cuda_graph)
+        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
+
+    def test_initialization_mtp_full_graph_async(self):
+        self.vllm_config.speculative_config.method = "mtp"
+        self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
+        self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
+        self.vllm_config.model_config.enforce_eager = False
+        self.vllm_config.speculative_config.enforce_eager = False
+        self.vllm_config.scheduler_config.async_scheduling = True
+        init_ascend_config(self.vllm_config)
+
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
@@ -110,6 +128,7 @@ class TestEagleProposerLoadModel(TestBase):
         self.vllm_config.speculative_config.method = "eagle"
         self.device = torch.device("cpu")
         self.runner = MagicMock()
+        self.runner.pin_memory = False
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
@@ -252,6 +271,7 @@ class TestEagleProposerDummyRun(TestBase):
         self.runner = MagicMock()
         self.runner.pcp_size = 1
         self.runner.dcp_size = 1
+        self.runner.pin_memory = False
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
@@ -279,6 +299,7 @@ class TestEagleProposerDummyRun(TestBase):
                                       device=self.device,
                                       runner=self.runner)
         self.proposer.model = MagicMock()
+        self.proposer._runnable = MagicMock()
         self.proposer.update_stream = MagicMock()

     def tearDown(self):
@@ -298,7 +319,7 @@ class TestEagleProposerDummyRun(TestBase):
         self.proposer.dummy_run(num_tokens=num_tokens,
                                 with_prefill=with_prefill)
-        self.assertTrue(self.proposer.model.call_count == 4)
+        self.assertTrue(self.proposer._runnable.call_count == 1)

     # cpu does not support parallel-group, let alone `sp`
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
@@ -309,7 +330,7 @@ class TestEagleProposerDummyRun(TestBase):
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
         self.proposer.enable_shared_expert_dp = False
         self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
-        self.assertTrue(self.proposer.model.call_count == 4)
+        self.assertTrue(self.proposer._runnable.call_count == 1)

     @patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@@ -329,7 +350,7 @@ class TestEagleProposerDummyRun(TestBase):
         self.proposer.dummy_run(num_tokens=64,
                                 in_graph_capturing=True,
                                 aclgraph_runtime_mode=CUDAGraphMode.FULL)
-        self.assertTrue(self.proposer.model.call_count == 4)
+        self.assertTrue(self.proposer._runnable.call_count == 1)
         mock_update_attn_params.assert_not_called()

         self.proposer.use_cuda_graph = last_use_cuda_graph
@@ -351,8 +372,8 @@ class TestEagleProposerDummyRun(TestBase):
         self.proposer.dummy_run(num_tokens=64,
                                 in_graph_capturing=False,
                                 aclgraph_runtime_mode=CUDAGraphMode.FULL)
-        self.assertTrue(self.proposer.model.call_count == 4)
-        self.assertTrue(mock_update_attn_params.call_count == 4)
+        self.assertTrue(self.proposer._runnable.call_count == 1)
+        self.assertTrue(mock_update_attn_params.call_count == 1)

         self.proposer.use_cuda_graph = last_use_cuda_graph
@@ -369,6 +390,7 @@ class TestEagleProposerHelperMethods(TestBase):
         self.runner.input_batch.req_ids = [0, 1, 2]
         self.runner.arange_np = np.arange(10)
         self.runner.input_batch.num_reqs = 3
+        self.runner.pin_memory = False
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024


@@ -74,6 +74,7 @@ class TestMtpProposer:
         runner.max_num_reqs = 256
         runner._use_aclgraph.return_value = False
         runner.reserved_mc2_mask = None
+        runner.pin_memory = False
         return runner

     @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")


@@ -191,7 +191,15 @@ class ACLGraphWrapper:
             # before the grph replay of iteration i-1.
             # To ensure proper ordering, we must call synchronize here before replaying,
             # so that update_attn_params only executes after the previous graph replay has fully completed.
-            torch.npu.synchronize()
+            # If we do not in main model and in full-graph mode when using merge-eagle-graph,
+            # we do not need to synchronize.
+            use_eagle = (
+                self.vllm_config.speculative_config.method in ("eagle", "eagle3")
+                if self.vllm_config.speculative_config
+                else False
+            )
+            if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle:
+                torch.npu.synchronize()
             entry.aclgraph.replay()
         return entry.output
@@ -247,18 +255,31 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
             event.record(update_stream)


-def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
+def _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas=None):
     if forward_context.is_draft_model:
         graph_params = get_draft_graph_params()
+        attn_metadata = draft_attn_metadatas
+        attn_keys = list(attn_metadata[0].keys())
     else:
         graph_params = get_graph_params()
+        attn_metadata = forward_context.attn_metadata
+        attn_keys = list(attn_metadata.keys())
     # For Qwen3-next, since the kv_cache_config has already categorized
     # linear_attn and self_attn, the attn_metadata is first arranged with
     # self_attn followed by linear_attn. Therefore, using zip directly
     # filters out the update operations for linear_attn.
+    # TODO: We use a new variable `attn_keys` to ensure the loop count is
+    # correct after get by `zip` because of the new structure of the attn_metadata
+    # when running with the merged full eagle-graph. Should check it with Qwen3-next.
+    num_layers = len(attn_keys)
+    if num_layers == 0:
+        return
+    if forward_context.is_draft_model:
+        attn_keys = attn_keys * (len(graph_params.attn_params[runtime_shape]) // num_layers)
+    attn_count = 0
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
-                forward_context.attn_metadata,
+                attn_keys,
                 graph_params.attn_params[runtime_shape],
                 graph_params.handles[runtime_shape],
                 graph_params.events[runtime_shape],
@@ -279,8 +300,15 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
                 softmax_lse,
             ) = param

-            seq_lens = forward_context.attn_metadata[key].seq_lens_list
-            actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
+            if forward_context.is_draft_model:
+                draft_step = attn_count // num_layers
+                seq_lens = attn_metadata[draft_step][key].seq_lens_list
+                actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
+                attn_count = attn_count + 1
+            else:
+                seq_lens = attn_metadata[key].seq_lens_list
+                actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q

             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu.npu_fused_infer_attention_score.out(
                 query=query,
@@ -304,11 +332,11 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
             event.record(update_stream)


-def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
+def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config, draft_attn_metadatas=None):
     if using_paged_attention(runtime_shape, vllm_config):
         _update_attn_pa_params(update_stream, forward_context, runtime_shape)
     else:
-        _update_attn_fia_params(update_stream, forward_context, runtime_shape)
+        _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas)


 def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import copy
 from contextlib import contextmanager, nullcontext
-from typing import Any, ContextManager, Optional
+from typing import Any, Callable, ContextManager, Optional, Union

 import numpy as np
 import torch
@@ -84,6 +85,8 @@ def split_inputs_tp_to_sp(hidden_states, out):

 class EagleProposer(VllmEagleProposer):
+    _runnable: Union[ACLGraphWrapper, Callable]
+
     def __init__(self,
                  vllm_config: VllmConfig,
                  device: torch.device,
@@ -136,14 +139,29 @@ class EagleProposer(VllmEagleProposer):
             self.tp_group_context = nullcontext()

         self.use_cuda_graph = (self.runner._use_aclgraph()
-                               and not self.speculative_config.enforce_eager
-                               and not self.use_async_scheduling)
+                               and not self.speculative_config.enforce_eager)
+        if self.method == "mtp":
+            self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling
+
         # TODO: Remove it when the bug of fx-graph is solved
         self.maybe_eager_context: ContextManager[Any] = nullcontext()
         if not self.use_cuda_graph and enable_sp(vllm_config):
             self.maybe_eager_context = _maybe_eager_context(vllm_config)

+        self.last_token_indices = torch.zeros(
+            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int32,
+            device=device)
+        slot_mapping_lens = self.runner.max_num_tokens + \
+            2 * self.pcp_size * self.runner.max_num_reqs
+        self.slot_mapping_group = [
+            torch.zeros(
+                slot_mapping_lens, dtype=torch.int32, device=device,
+                pin_memory=self.runner.pin_memory)
+            for _ in range(self.num_speculative_tokens)]
+        self._runnable = self._run_merged_draft
+
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
@@ -166,7 +184,17 @@ class EagleProposer(VllmEagleProposer):
         draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
         draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names

         assert len(draft_attn_layer_names) == 1
-        self.attn_layer_names = list(draft_attn_layer_names)
+        self.attn_layer_names = list(sorted(draft_attn_layer_names))
+        self.piece_all_attn_layer_name = []
+        for _ in range(self.num_speculative_tokens):
+            self.piece_all_attn_layer_name.append([
+                name for name in self.attn_layer_names])
+
+        self.attn_layer_names = list(sorted(draft_attn_layer_names))
+        self.piece_all_attn_layer_name = []
+        for _ in range(self.num_speculative_tokens):
+            self.piece_all_attn_layer_name.append([
+                name for name in self.attn_layer_names])

         if supports_multimodal(model):
             # handle multimodality
@@ -236,9 +264,14 @@ class EagleProposer(VllmEagleProposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ) and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
-            self.model = ACLGraphWrapper(self.model,
-                                         self.vllm_config,
-                                         runtime_mode=CUDAGraphMode.FULL)
+            if self.method == "mtp":
+                self.model = ACLGraphWrapper(self.model,
+                                             self.vllm_config,
+                                             runtime_mode=CUDAGraphMode.FULL)
+            else:
+                self._runnable = ACLGraphWrapper(self._run_merged_draft,
+                                                 self.vllm_config,
+                                                 runtime_mode=CUDAGraphMode.FULL)

     def get_model(self) -> nn.Module:
         # get raw model out of the aclgraph wrapper.
@@ -246,6 +279,11 @@ class EagleProposer(VllmEagleProposer):
             return self.model.unwrap()
         return self.model

+    def shallow_copy_metadata(self, attn_metadata):
+        # Currently, new objects will be assigned to the lists in attn_metadata
+        # when update. So we can use the shallow copy.
+        return copy.copy(attn_metadata)
+
     @torch.inference_mode()
     def dummy_run(self,
                   num_tokens: int,
@@ -260,7 +298,7 @@ class EagleProposer(VllmEagleProposer):
         # update global cos, sin
         update_cos_sin(self._get_positions(num_tokens))

-        attn_metadata = None
+        multi_steps_attn_metadata = []
         if not self.use_cuda_graph:
             aclgraph_runtime_mode = CUDAGraphMode.NONE
         if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(
@@ -286,6 +324,7 @@ class EagleProposer(VllmEagleProposer):
                 actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                 block_table_tensor=self.runner.input_batch.block_table[0].
                 get_device_tensor()[:num_reqs],
+                # This is used to hold a position.
                 slot_mapping=self.runner.input_batch.block_table[0].
                 slot_mapping.gpu,
                 positions=self.runner.positions.gpu,
@@ -295,46 +334,49 @@
             )
             builder = self.runner.attn_groups[0][0].get_metadata_builder()
-            attn_metadata_eagle = builder.build_for_graph_capture(
-                common_attn_metadata, AscendAttentionState.ChunkedPrefill)
-            attn_metadata = {}
-            for layer_name in self.attn_layer_names:
-                attn_metadata[layer_name] = attn_metadata_eagle
+            # update the tensor's address for each step.
+            for draft_step in range(self.num_speculative_tokens):
+                common_attn_metadata = self.shallow_copy_metadata(
+                    common_attn_metadata)
+                # Set the real slot_mapping.
+                common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
+                attn_metadata_eagle = builder.build_for_graph_capture(
+                    common_attn_metadata, AscendAttentionState.ChunkedPrefill)
+                per_layer_attn_metadata = dict()
+                for layer_name in self.attn_layer_names:
+                    per_layer_attn_metadata[layer_name] = attn_metadata_eagle
+                multi_steps_attn_metadata.append(per_layer_attn_metadata)

         model_input_ids = self.input_ids[:num_tokens]
         model_positions = self._get_positions(num_tokens)
         model_previous_hidden_states = self.hidden_states[:num_tokens]
-        for i in range(self.num_speculative_tokens):
-            if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
-                aclgraph_runtime_mode = CUDAGraphMode.NONE
-            with set_ascend_forward_context(
-                    attn_metadata,
-                    self.vllm_config,
-                    num_tokens=num_tokens,
-                    num_actual_tokens=0,
-                    in_profile_run=is_profile,
-                    batch_descriptor=batch_descriptor,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    is_draft_model=True):
-                model_previous_hidden_states, model_positions = self.maybe_pad_and_reduce(
-                    model_previous_hidden_states, model_positions)
-                self.model(
-                    input_ids=model_input_ids,
-                    positions=model_positions,
-                    hidden_states=model_previous_hidden_states,
-                )
-                forward_context = get_forward_context()
-                if (forward_context.cudagraph_runtime_mode
-                        == CUDAGraphMode.FULL
-                        and not forward_context.capturing):
-                    self._update_full_graph_params(forward_context, num_tokens)
-
-                model_previous_hidden_states, model_positions, _ = self.maybe_all_gather_and_unpad(
-                    model_previous_hidden_states, model_positions)
-
-        dummy_compute_logits(self.hidden_states)
+        batch_size = num_tokens // (self.num_speculative_tokens + 1)
+        with set_ascend_forward_context(
+                multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_actual_tokens=0,
+                in_profile_run=is_profile,
+                batch_descriptor=batch_descriptor,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                is_draft_model=True):
+            self._runnable(
+                num_input_tokens=num_tokens,
+                batch_size=batch_size,
+                last_token_indices=self.last_token_indices[:batch_size],
+                # The target_position's address is same as the model_positions's
+                target_positions=model_positions,
+                inputs_embeds=None,
+                multi_steps_attn_metadata=multi_steps_attn_metadata,
+            )
+            forward_context = get_forward_context()
+            if (forward_context.cudagraph_runtime_mode
+                    == CUDAGraphMode.FULL
+                    and not forward_context.capturing):
+                self._update_full_graph_params(forward_context, num_tokens,
+                                               multi_steps_attn_metadata)

     def _propose(
         self,
@@ -408,17 +450,59 @@
             inputs_embeds = None
         input_ids = self.input_ids[:num_input_tokens]

+        # Update slot_mapping for different speculative.
+        # NOTE: Currently, we only remake the slot_mapping, because it's the
+        # only tensor which will be used in current FIA.
+        # Strictly speaking, `query_start_loc`, `seq_lens` should also have
+        # their memory allocated separately for each step just like `slot_mapping`.
+        slot_mapping_lens = num_input_tokens if num_input_tokens < \
+            common_attn_metadata.slot_mapping.shape[0] else \
+            common_attn_metadata.slot_mapping.shape[0]
+        self.slot_mapping_group[0][:slot_mapping_lens].copy_(
+            common_attn_metadata.slot_mapping[:slot_mapping_lens])
+        self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
+        common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
+
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata,
                                       self.runner.get_model())
         # update global cos, sin
         update_cos_sin(self._get_positions(num_input_tokens))

-        per_layer_attn_metadata = {}
+        if self.uses_mrope:
+            used_update_positions = target_positions[:, last_token_indices]
+        else:
+            used_update_positions = target_positions[last_token_indices]
+
+        per_layer_attn_metadata = dict()
+        # The first step of speculative.
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
+        multi_steps_attn_metadata = [per_layer_attn_metadata]
+        # Copy the old attn_metadata and update
+        for draft_step in range(1, self.num_speculative_tokens):
+            common_attn_metadata, attn_metadata = \
+                self.attn_update_stack_num_spec_norm(
+                    draft_step,
+                    attn_metadata,
+                    common_attn_metadata,
+                    batch_size,
+                    num_input_tokens,
+                    used_update_positions,
+                    aclgraph_runtime_mode)
+            per_layer_attn_metadata = dict()
+            for layer_name in self.attn_layer_names:
+                per_layer_attn_metadata[layer_name] = attn_metadata
+            multi_steps_attn_metadata.append(per_layer_attn_metadata)
+
+        last_token_indices_len = last_token_indices.shape[0]
+        self.last_token_indices[:last_token_indices_len].copy_(
+            last_token_indices)

         with set_ascend_forward_context(
-                per_layer_attn_metadata,
+                multi_steps_attn_metadata[0],
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_actual_tokens=num_tokens,
@@ -426,34 +510,52 @@
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):
-            # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
-            # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
-            model_input_ids = self.input_ids[:num_input_tokens]
-            model_positions = self._get_positions(num_input_tokens)
-            model_hidden_states = self.hidden_states[:num_input_tokens]
-
-            model_hidden_states, model_positions = self.maybe_pad_and_reduce(
-                model_hidden_states, model_positions)
-            ret_hidden_states = self.model(
-                input_ids=model_input_ids,
-                positions=model_positions,
-                hidden_states=model_hidden_states,
-                inputs_embeds = inputs_embeds
-            )
-            if self.method == "mtp":
-                last_hidden_states = ret_hidden_states
-                hidden_states = last_hidden_states
-            else:
-                last_hidden_states, hidden_states = ret_hidden_states
+            draft_token_ids = self._runnable(
+                num_input_tokens=num_input_tokens,
+                batch_size=batch_size,
+                last_token_indices=self.last_token_indices[:last_token_indices_len],
+                target_positions=target_positions,
+                inputs_embeds=inputs_embeds,
+                multi_steps_attn_metadata=multi_steps_attn_metadata)

             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 self._update_full_graph_params(forward_context,
-                                               num_input_tokens)
+                                               num_input_tokens,
+                                               multi_steps_attn_metadata)
+        return draft_token_ids

-            last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
-                last_hidden_states, model_positions, hidden_states)
+    def _run_merged_draft(self,
+                          num_input_tokens,
+                          batch_size,
+                          last_token_indices,
+                          target_positions,
+                          inputs_embeds,
+                          multi_steps_attn_metadata,
+                          ) -> torch.Tensor:
+        # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
+        # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
+        model_input_ids = self.input_ids[:num_input_tokens]
+        model_positions = self._get_positions(num_input_tokens)
+        model_hidden_states = self.hidden_states[:num_input_tokens]
+
+        model_hidden_states, model_positions = self.maybe_pad_and_reduce(
+            model_hidden_states, model_positions)
+        ret_hidden_states = self.model(
+            input_ids=model_input_ids,
+            positions=model_positions,
+            hidden_states=model_hidden_states,
+            inputs_embeds = inputs_embeds,
+        )
+        if self.method == "mtp":
+            last_hidden_states = ret_hidden_states
+            hidden_states = last_hidden_states
+        else:
+            last_hidden_states, hidden_states = ret_hidden_states
+
+        last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
+            last_hidden_states, model_positions, hidden_states)

         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states)
@@ -477,53 +579,17 @@
         hidden_states = hidden_states[last_token_indices]
         last_token_indices = self.arange[:batch_size]

-        if self.use_cuda_graph and \
-                batch_size <= self.runner.cudagraph_batch_sizes[-1]:
-            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
-        else:
-            input_batch_size = batch_size
-
-        if self.use_cuda_graph:
-            aclgraph_runtime_mode, batch_descriptor = \
-                self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
-        else:
-            aclgraph_runtime_mode = CUDAGraphMode.NONE
-            batch_descriptor = None
-
-        if (
-            aclgraph_runtime_mode == CUDAGraphMode.FULL
-            and (pad_size := input_batch_size - batch_size) > 0
-        ):
-            common_attn_metadata.num_reqs = input_batch_size
-            common_attn_metadata.block_table_tensor = self._pad_tensor(
-                common_attn_metadata.block_table_tensor, pad_size)
-            common_attn_metadata.seq_lens = self._pad_tensor(
-                common_attn_metadata.seq_lens, pad_size)
-            common_attn_metadata.seq_lens_cpu = self._pad_tensor(
-                common_attn_metadata.seq_lens_cpu, pad_size)
-            common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
-                common_attn_metadata.num_computed_tokens_cpu, pad_size)
-            common_attn_metadata.query_start_loc = self.arange[
-                :input_batch_size + 1]
-            common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
-                self.token_arange_np[:input_batch_size + 1]).clone()
-        else:
-            common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-            common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
-                self.token_arange_np[:batch_size + 1]).clone()
-        common_attn_metadata.num_actual_tokens = batch_size
-        common_attn_metadata.max_query_len = 1
-        common_attn_metadata.decode_token_per_req = 1
-        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
-        common_attn_metadata.graph_pad_size = -1
-        common_attn_metadata.num_input_tokens = input_batch_size
-
-        for now_speculative in range(self.num_speculative_tokens - 1):
+        input_batch_size = num_input_tokens
+
+        forward_context = get_forward_context()
+        forward_context.num_tokens = input_batch_size
+        forward_context.num_accept_tokens = batch_size
+
+        for draft_step in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_tensor[now_speculative]
+            input_ids = draft_token_ids_tensor[draft_step]
             positions += 1

             # NOTE(woosuk): We should handle the case where the draft model
@@ -545,67 +611,6 @@
             clamped_positions = torch.where(exceeds_max_model_len, 0,
                                             positions)
-            # For data integrity when async scheduling, we shouldn't use in place
-            # operations in case they are modified in next step's `prepare_input`
-            # of main model.
-            # Increment the sequence lengths.
-            common_attn_metadata.seq_lens[:batch_size] += 1
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            common_attn_metadata.seq_lens[:batch_size].masked_fill_(
-                exceeds_max_model_len, 1)
-            common_attn_metadata.seq_lens_cpu[:batch_size] = (
-                common_attn_metadata.seq_lens_cpu[:batch_size] + 1)
-            exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \
-                self.max_model_len
-            common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(
-                exceeds_mask, 1)
-            common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1
-            if self.uses_mrope:
-                common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0])
-            else:
-                common_attn_metadata.positions[:batch_size].copy_(clamped_positions)
-
-            if self.attn_metadata_builder is None:
-                attn_metadata_builder = self._get_attention_metadata_builder()
-            else:
-                attn_metadata_builder = self.attn_metadata_builder
-            block_size = attn_metadata_builder.kv_cache_spec.block_size
-
-            # Compute the slot mapping.
-            if self.uses_mrope:
-                block_numbers = clamped_positions[0] // block_size
-            else:
-                block_numbers = (clamped_positions // block_size)
-            block_ids = attn_metadata.block_tables.gather(
-                dim=1, index=block_numbers.view(-1, 1))
-            block_ids = block_ids.view(-1)
-            if self.uses_mrope:
-                slot_mapping = (block_ids * block_size +
-                                clamped_positions[0] % block_size)
-            else:
-                slot_mapping = (block_ids * block_size +
-                                clamped_positions % block_size)
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            slot_mapping.masked_fill_(exceeds_max_model_len,
-                                      PADDING_SLOT_ID)
-            common_attn_metadata.slot_mapping[:slot_mapping.shape[0]].copy_(
-                slot_mapping.to(torch.int32))
-            common_attn_metadata.slot_mapping[slot_mapping.shape[0]:].fill_(
-                PADDING_SLOT_ID)
-
-            # Rebuild attention metadata
-            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
-                common_attn_metadata=common_attn_metadata,
-                draft_index=now_speculative + 1,
-            )
-            for layer_name in self.attn_layer_names:
-                per_layer_attn_metadata[layer_name] = attn_metadata
-
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
             self._set_positions(batch_size, clamped_positions)
@@ -624,55 +629,175 @@
             update_cos_sin(self._get_positions(input_batch_size))

             # Run the model.
-            with set_ascend_forward_context(
-                    per_layer_attn_metadata,
-                    self.vllm_config,
-                    num_tokens=input_batch_size,
-                    num_actual_tokens=batch_size,
-                    batch_descriptor=batch_descriptor,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    is_draft_model=True):
-                # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
-                # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
-                model_input_ids = self.input_ids[:input_batch_size]
-                model_positions = self._get_positions(input_batch_size)
-                model_hidden_states = self.hidden_states[:input_batch_size]
-
-                model_hidden_states, model_positions = self.maybe_pad_and_reduce(
-                    model_hidden_states, model_positions)
-
-                ret_hidden_states = self.model(
-                    input_ids=model_input_ids,
-                    positions=model_positions,
-                    hidden_states=model_hidden_states,
-                    inputs_embeds = inputs_embeds
-                )
-                if self.method == "mtp":
-                    last_hidden_states = ret_hidden_states
-                    hidden_states = last_hidden_states
-                else:
-                    last_hidden_states, hidden_states = ret_hidden_states
-
-                forward_context = get_forward_context()
-                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-                    self._update_full_graph_params(forward_context,
-                                                   input_batch_size)
-
-                last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
-                    last_hidden_states, model_positions, hidden_states)
+            # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
+            # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
+            model_input_ids = self.input_ids[:input_batch_size]
+            model_positions = self._get_positions(input_batch_size)
+            model_hidden_states = self.hidden_states[:input_batch_size]
+
+            model_hidden_states, model_positions = self.maybe_pad_and_reduce(
+                model_hidden_states, model_positions)
+
+            forward_context.attn_metadata = multi_steps_attn_metadata[draft_step + 1] \
+                if multi_steps_attn_metadata else None
+            ret_hidden_states = self.model(
+                input_ids=model_input_ids,
+                positions=model_positions,
+                hidden_states=model_hidden_states,
+                inputs_embeds = inputs_embeds,
+            )
+            if self.method == "mtp":
+                last_hidden_states = ret_hidden_states
+                hidden_states = last_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
+
+            last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
+                last_hidden_states, model_positions, hidden_states)

             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size])

             # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
-            draft_token_ids_tensor[now_speculative + 1] = draft_token_ids
+            draft_token_ids_tensor[draft_step + 1] = draft_token_ids

         # [batch_size, num_speculative_tokens]
         draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
         return draft_token_ids
+
+    def attn_update_stack_num_spec_norm(self,
+                                        # `draft_step` must start from `1`, no `0`
+                                        draft_step,
+                                        old_attn_metadata,
+                                        old_common_metadata,
+                                        batch_size,
+                                        input_batch_size,
+                                        used_update_positions,
+                                        aclgraph_runtime_mode):
+        assert(draft_step > 0)
+        common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
+        if draft_step == 1:
+            if (
+                aclgraph_runtime_mode == CUDAGraphMode.FULL
+                and (pad_size := input_batch_size - batch_size) > 0
+            ):
+                common_attn_metadata.num_reqs = input_batch_size
+                common_attn_metadata.block_table_tensor = self._pad_tensor(
+                    common_attn_metadata.block_table_tensor, pad_size)
+                common_attn_metadata.seq_lens = self._pad_tensor(
+                    common_attn_metadata.seq_lens, pad_size)
+                common_attn_metadata.seq_lens_cpu = self._pad_tensor(
+                    common_attn_metadata.seq_lens_cpu, pad_size)
+                common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
+                    common_attn_metadata.num_computed_tokens_cpu, pad_size)
+                common_attn_metadata.query_start_loc = self.arange[
+                    :input_batch_size + 1]
+                common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
+                    self.token_arange_np[:input_batch_size + 1]).clone()
+            else:
+                common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+                common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
+                    self.token_arange_np[:batch_size + 1]).clone()
+            common_attn_metadata.num_actual_tokens = batch_size
+            common_attn_metadata.max_query_len = 1
+            common_attn_metadata.decode_token_per_req = 1
+            common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+            common_attn_metadata.graph_pad_size = -1
+            common_attn_metadata.num_input_tokens = input_batch_size
+
+        # The loop part
+        used_update_positions += 1
+
+        # NOTE(woosuk): We should handle the case where the draft model
+        # generates tokens beyond the max model length. Since it is complex
+        # to remove such requests from the batch, we keep them in the batch
+        # but adjust the position ids and slot mappings to avoid the
+        # out-of-range access during the model execution. The draft tokens
+        # generated with this adjustment should be ignored.
+        if self.uses_mrope:
+            exceeds_max_model_len = used_update_positions[
+                0] >= self.vllm_config.model_config.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(
+                exceeds_max_model_len.unsqueeze(0),
+                torch.zeros_like(used_update_positions), used_update_positions)
+        else:
+            exceeds_max_model_len = used_update_positions >= \
+                self.vllm_config.model_config.max_model_len
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            used_update_positions)
+
+        # For data integrity when async scheduling, we shouldn't use in place
+        # operations in case they are modified in next step's `prepare_input`
+        # of main model.
+        # Increment the sequence lengths.
+        common_attn_metadata.seq_lens[:batch_size] += 1
+        # For the requests that exceed the max model length, we set the
+        # sequence length to 1 to minimize their overheads in attention.
+        common_attn_metadata.seq_lens[:batch_size].masked_fill_(
+            exceeds_max_model_len, 1)
+        common_attn_metadata.seq_lens_cpu[:batch_size] = (
+            common_attn_metadata.seq_lens_cpu[:batch_size] + 1)
+        exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \
+            self.max_model_len
+        common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(
+            exceeds_mask, 1)
+        common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1
+        if self.uses_mrope:
+            common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0])
+        else:
+            common_attn_metadata.positions[:batch_size].copy_(clamped_positions)
+
+        if self.attn_metadata_builder is None:
+            attn_metadata_builder = self._get_attention_metadata_builder()
+        else:
+            attn_metadata_builder = self.attn_metadata_builder
+        block_size = attn_metadata_builder.kv_cache_spec.block_size
+
+        # Compute the slot mapping.
+        if self.uses_mrope:
+            block_numbers = clamped_positions[0] // block_size
+        else:
+            block_numbers = (clamped_positions // block_size)
+        block_ids = old_attn_metadata.block_tables.gather(
+            dim=1, index=block_numbers.view(-1, 1))
+        block_ids = block_ids.view(-1)
+        if self.uses_mrope:
+            slot_mapping = (block_ids * block_size +
+                            clamped_positions[0] % block_size)
+        else:
+            slot_mapping = (block_ids * block_size +
+                            clamped_positions % block_size)
+        # Mask out the slot mappings that exceed the max model length.
+        # Otherwise, the KV cache will be inadvertently updated with the
+        # padding tokens.
+        slot_mapping.masked_fill_(exceeds_max_model_len,
+                                  PADDING_SLOT_ID)
+        self.slot_mapping_group[draft_step][:slot_mapping.shape[0]].copy_(
+            slot_mapping.to(torch.int32))
+        self.slot_mapping_group[draft_step][slot_mapping.shape[0]:].fill_(
+            PADDING_SLOT_ID)
+        # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
+        common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][
+            :slot_mapping.shape[0]]
+
+        # Rebuild attention metadata
+        attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
+            common_attn_metadata=common_attn_metadata,
+            draft_index=draft_step,
+        )
+        return common_attn_metadata, attn_metadata

     def prepare_next_token_ids_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -1011,7 +1136,7 @@ class EagleProposer(VllmEagleProposer):
         return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens

     # update full-graph params for one spec token
-    def _update_full_graph_params(self, forward_context, num_tokens):
+    def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
         if self.vllm_config.model_config.use_mla:
             if self.pcp_size * self.dcp_size > 1:
                 update_mla_attn_dcp_pcp_params(self.update_stream,
@@ -1026,7 +1151,7 @@ class EagleProposer(VllmEagleProposer):
                                                num_tokens)
         else:
             update_attn_params(self.update_stream, forward_context,
-                               num_tokens, self.vllm_config)
+                               num_tokens, self.vllm_config, draft_attn_metadatas)

     # padding tensor into desired size
     def _pad_tensor(self, tensor, pad_size):