From 7725314b26b0ae40e9fc622bc5d705e9b5d11124 Mon Sep 17 00:00:00 2001
From: anon189Ty
Date: Fri, 23 Jan 2026 08:37:02 +0800
Subject: [PATCH] [Feat] Merge the multi eagle graphs to one graph (#5940)

### What this PR does / why we need it?
This PR merges all steps of the draft model into one graph in fullgraph mode, avoiding the synchronization between graphs and reducing the bubble time.

#### Key ideas:
- The "model forward" of step 0 (the first step) and of the remaining steps is captured together as a single "Callable", rather than capturing each step's forward individually.
- "update_attn_params" is moved outside the entire graph: the "attn_metadata" required by all steps is constructed before "replay", and the "attn_params" of all steps are updated at once.
- Remove the synchronization between the main model graph and the draft model graph.

#### Key params/functions:
- params: draft_attn_metadatas, attn_metadata_multi_steps, slot_mapping_group
- functions: _run_merged_draft, attn_update_stack_num_spec_norm, update_attn_params, _propose, dummy_run

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9

Signed-off-by: anon189Ty
---
 tests/ut/compilation/test_acl_graph.py      |   2 +
 tests/ut/spec_decode/test_eagle_proposer.py |  32 +-
 tests/ut/spec_decode/test_mtp_proposer.py   |   1 +
 vllm_ascend/compilation/acl_graph.py        |  42 +-
 vllm_ascend/spec_decode/eagle_proposer.py   | 537 ++++++++++++--------
 5 files changed, 396 insertions(+), 218 deletions(-)

diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py
index 0261971f..ce2cf592 100644
--- a/tests/ut/compilation/test_acl_graph.py
+++ b/tests/ut/compilation/test_acl_graph.py
@@ -295,6 +295,7 @@ class TestACLGraphWrapper(TestBase):
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
+        self.mock_forward_context.is_draft_model = False

         # Mock torch.npu.NPUGraph
         mock_npu_graph = MagicMock()
@@ -366,6 +367,7 @@ class TestACLGraphWrapper(TestBase):
         mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
         mock_get_forward_context.return_value = self.mock_forward_context
         self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
+        self.mock_forward_context.is_draft_model = False

         # Mock torch.npu.NPUGraph
         mock_npu_graph = MagicMock()
diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index fda58137..a6d6ef85 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -20,6 +20,7 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config = MagicMock()
         self.device = torch.device("cpu")
         self.runner = MagicMock()
+        self.runner.pin_memory = False

         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
@@ -93,6 +94,23 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.scheduler_config.async_scheduling = True
         init_ascend_config(self.vllm_config)

+        proposer = EagleProposer(vllm_config=self.vllm_config,
+                                 device=self.device,
+                                 runner=self.runner)
+
+        self.assertEqual(proposer.hidden_size, 2048)
+        self.assertTrue(proposer.use_cuda_graph)
+        self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
+
+    def 
test_initialization_mtp_full_graph_async(self): + self.vllm_config.speculative_config.method = "mtp" + self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048 + self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE + self.vllm_config.model_config.enforce_eager = False + self.vllm_config.speculative_config.enforce_eager = False + self.vllm_config.scheduler_config.async_scheduling = True + init_ascend_config(self.vllm_config) + proposer = EagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -110,6 +128,7 @@ class TestEagleProposerLoadModel(TestBase): self.vllm_config.speculative_config.method = "eagle" self.device = torch.device("cpu") self.runner = MagicMock() + self.runner.pin_memory = False self.vllm_config.cache_config.block_size = 16 self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 @@ -252,6 +271,7 @@ class TestEagleProposerDummyRun(TestBase): self.runner = MagicMock() self.runner.pcp_size = 1 self.runner.dcp_size = 1 + self.runner.pin_memory = False self.vllm_config.cache_config.block_size = 16 self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 @@ -279,6 +299,7 @@ class TestEagleProposerDummyRun(TestBase): device=self.device, runner=self.runner) self.proposer.model = MagicMock() + self.proposer._runnable = MagicMock() self.proposer.update_stream = MagicMock() def tearDown(self): @@ -298,7 +319,7 @@ class TestEagleProposerDummyRun(TestBase): self.proposer.dummy_run(num_tokens=num_tokens, with_prefill=with_prefill) - self.assertTrue(self.proposer.model.call_count == 4) + self.assertTrue(self.proposer._runnable.call_count == 1) # cpu does not support parallel-group, let alone `sp` @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context", @@ -309,7 +330,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` self.proposer.enable_shared_expert_dp = False self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4) - self.assertTrue(self.proposer.model.call_count == 4) + self.assertTrue(self.proposer._runnable.call_count == 1) @patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params") @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context") @@ -329,7 +350,7 @@ class TestEagleProposerDummyRun(TestBase): self.proposer.dummy_run(num_tokens=64, in_graph_capturing=True, aclgraph_runtime_mode=CUDAGraphMode.FULL) - self.assertTrue(self.proposer.model.call_count == 4) + self.assertTrue(self.proposer._runnable.call_count == 1) mock_update_attn_params.assert_not_called() self.proposer.use_cuda_graph = last_use_cuda_graph @@ -351,8 +372,8 @@ class TestEagleProposerDummyRun(TestBase): self.proposer.dummy_run(num_tokens=64, in_graph_capturing=False, aclgraph_runtime_mode=CUDAGraphMode.FULL) - self.assertTrue(self.proposer.model.call_count == 4) - self.assertTrue(mock_update_attn_params.call_count == 4) + self.assertTrue(self.proposer._runnable.call_count == 1) + self.assertTrue(mock_update_attn_params.call_count == 1) self.proposer.use_cuda_graph = last_use_cuda_graph @@ -369,6 +390,7 @@ class TestEagleProposerHelperMethods(TestBase): self.runner.input_batch.req_ids = [0, 1, 2] self.runner.arange_np = np.arange(10) self.runner.input_batch.num_reqs = 3 + self.runner.pin_memory = False self.vllm_config.cache_config.block_size = 16 self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 
d2bb0533..e800a8d5 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -74,6 +74,7 @@ class TestMtpProposer: runner.max_num_reqs = 256 runner._use_aclgraph.return_value = False runner.reserved_mc2_mask = None + runner.pin_memory = False return runner @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index dbfbc1bb..56613ac9 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -191,7 +191,15 @@ class ACLGraphWrapper: # before the grph replay of iteration i-1. # To ensure proper ordering, we must call synchronize here before replaying, # so that update_attn_params only executes after the previous graph replay has fully completed. - torch.npu.synchronize() + # If we do not in main model and in full-graph mode when using merge-eagle-graph, + # we do not need to synchronize. + use_eagle = ( + self.vllm_config.speculative_config.method in ("eagle", "eagle3") + if self.vllm_config.speculative_config + else False + ) + if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle: + torch.npu.synchronize() entry.aclgraph.replay() return entry.output @@ -247,18 +255,31 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape): event.record(update_stream) -def _update_attn_fia_params(update_stream, forward_context, runtime_shape): +def _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas=None): if forward_context.is_draft_model: graph_params = get_draft_graph_params() + attn_metadata = draft_attn_metadatas + attn_keys = list(attn_metadata[0].keys()) else: graph_params = get_graph_params() + attn_metadata = forward_context.attn_metadata + attn_keys = list(attn_metadata.keys()) # For Qwen3-next, since the kv_cache_config has already categorized # linear_attn and self_attn, the attn_metadata is first arranged with # self_attn followed by linear_attn. Therefore, using zip directly # filters out the update operations for linear_attn. + # TODO: We use a new variable `attn_keys` to ensure the loop count is + # correct after get by `zip` because of the new structure of the attn_metadata + # when running with the merged full eagle-graph. Should check it with Qwen3-next. 
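For example, with one draft attention layer and three speculative steps, the capture pass records one `attn_params` entry per (step, layer) pair, so the layer keys must be repeated per step for `zip` to walk the whole list. A minimal, self-contained sketch of that indexing (the names below are illustrative stand-ins, not the real `graph_params` structures):

```python
layers = ["model.layers.0.self_attn.attn"]   # the draft model has a single attn layer
num_speculative_tokens = 3                   # assumed number of draft steps

# During capture, one params entry is recorded per (step, layer), in step order.
captured_params = [f"param(step={step}, layer={layer})"
                   for step in range(num_speculative_tokens) for layer in layers]

num_layers = len(layers)
# Repeat the keys so zip() walks every captured entry instead of stopping
# after the first step's layers.
attn_keys = layers * (len(captured_params) // num_layers)

for attn_count, (key, param) in enumerate(zip(attn_keys, captured_params)):
    draft_step = attn_count // num_layers    # which step's metadata to read
    print(draft_step, key, param)
```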
+ num_layers = len(attn_keys) + if num_layers == 0: + return + if forward_context.is_draft_model: + attn_keys = attn_keys * (len(graph_params.attn_params[runtime_shape]) // num_layers) + attn_count = 0 with torch.npu.stream(update_stream): for key, param, handle, event in zip( - forward_context.attn_metadata, + attn_keys, graph_params.attn_params[runtime_shape], graph_params.handles[runtime_shape], graph_params.events[runtime_shape], @@ -279,8 +300,15 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape): softmax_lse, ) = param - seq_lens = forward_context.attn_metadata[key].seq_lens_list - actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q + if forward_context.is_draft_model: + draft_step = attn_count // num_layers + seq_lens = attn_metadata[draft_step][key].seq_lens_list + actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q + attn_count = attn_count + 1 + else: + seq_lens = attn_metadata[key].seq_lens_list + actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q + torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( query=query, @@ -304,11 +332,11 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape): event.record(update_stream) -def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config): +def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config, draft_attn_metadatas=None): if using_paged_attention(runtime_shape, vllm_config): _update_attn_pa_params(update_stream, forward_context, runtime_shape) else: - _update_attn_fia_params(update_stream, forward_context, runtime_shape) + _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas) def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config): diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index a8d657bd..789dd091 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import copy from contextlib import contextmanager, nullcontext -from typing import Any, ContextManager, Optional +from typing import Any, Callable, ContextManager, Optional, Union import numpy as np import torch @@ -84,6 +85,8 @@ def split_inputs_tp_to_sp(hidden_states, out): class EagleProposer(VllmEagleProposer): + _runnable: Union[ACLGraphWrapper, Callable] + def __init__(self, vllm_config: VllmConfig, device: torch.device, @@ -136,14 +139,29 @@ class EagleProposer(VllmEagleProposer): self.tp_group_context = nullcontext() self.use_cuda_graph = (self.runner._use_aclgraph() - and not self.speculative_config.enforce_eager - and not self.use_async_scheduling) + and not self.speculative_config.enforce_eager) + if self.method == "mtp": + self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling # TODO: Remove it when the bug of fx-graph is solved self.maybe_eager_context: ContextManager[Any] = nullcontext() if not self.use_cuda_graph and enable_sp(vllm_config): self.maybe_eager_context = _maybe_eager_context(vllm_config) + self.last_token_indices = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int32, + device=device) + slot_mapping_lens = self.runner.max_num_tokens + \ + 2 * self.pcp_size * self.runner.max_num_reqs + self.slot_mapping_group = [ + torch.zeros( + slot_mapping_lens, dtype=torch.int32, 
device=device, + pin_memory=self.runner.pin_memory) + for _ in range(self.num_speculative_tokens)] + + self._runnable = self._run_merged_draft + def load_model(self, model: nn.Module) -> None: target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, @@ -166,7 +184,17 @@ class EagleProposer(VllmEagleProposer): draft_indexer_layer_names = indexer_layers - target_indexer_layer_names draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names assert len(draft_attn_layer_names) == 1 - self.attn_layer_names = list(draft_attn_layer_names) + self.attn_layer_names = list(sorted(draft_attn_layer_names)) + self.piece_all_attn_layer_name = [] + for _ in range(self.num_speculative_tokens): + self.piece_all_attn_layer_name.append([ + name for name in self.attn_layer_names]) + self.attn_layer_names = list(sorted(draft_attn_layer_names)) + + self.piece_all_attn_layer_name = [] + for _ in range(self.num_speculative_tokens): + self.piece_all_attn_layer_name.append([ + name for name in self.attn_layer_names]) if supports_multimodal(model): # handle multimodality @@ -236,9 +264,14 @@ class EagleProposer(VllmEagleProposer): if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( ) and self.use_cuda_graph: self.update_stream = torch.npu.Stream() - self.model = ACLGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + if self.method == "mtp": + self.model = ACLGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + else: + self._runnable = ACLGraphWrapper(self._run_merged_draft, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. @@ -246,6 +279,11 @@ class EagleProposer(VllmEagleProposer): return self.model.unwrap() return self.model + def shallow_copy_metadata(self, attn_metadata): + # Currently, new objects will be assigned to the lists in attn_metadata + # when update. So we can use the shallow copy. + return copy.copy(attn_metadata) + @torch.inference_mode() def dummy_run(self, num_tokens: int, @@ -260,7 +298,7 @@ class EagleProposer(VllmEagleProposer): # update global cos, sin update_cos_sin(self._get_positions(num_tokens)) - attn_metadata = None + multi_steps_attn_metadata = [] if not self.use_cuda_graph: aclgraph_runtime_mode = CUDAGraphMode.NONE if aclgraph_runtime_mode == CUDAGraphMode.FULL and len( @@ -286,6 +324,7 @@ class EagleProposer(VllmEagleProposer): actual_seq_lengths_q=self.runner.actual_seq_lengths_q, block_table_tensor=self.runner.input_batch.block_table[0]. get_device_tensor()[:num_reqs], + # This is used to hold a position. slot_mapping=self.runner.input_batch.block_table[0]. slot_mapping.gpu, positions=self.runner.positions.gpu, @@ -295,46 +334,49 @@ class EagleProposer(VllmEagleProposer): ) builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_eagle = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.ChunkedPrefill) - attn_metadata = {} - for layer_name in self.attn_layer_names: - attn_metadata[layer_name] = attn_metadata_eagle + # update the tensor's address for each step. + for draft_step in range(self.num_speculative_tokens): + common_attn_metadata = self.shallow_copy_metadata( + common_attn_metadata) + # Set the real slot_mapping. 
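A small sketch of why a shallow copy per step is sufficient in `shallow_copy_metadata`: the per-step fields (such as `slot_mapping`) are reassigned to new objects on the copy rather than mutated in place, so the tensors shared with the original metadata stay untouched. `Meta` below is a hypothetical stand-in for the common attention metadata:

```python
import copy
from dataclasses import dataclass

import torch


@dataclass
class Meta:
    # Hypothetical stand-in for the common attention metadata.
    slot_mapping: torch.Tensor
    seq_lens: torch.Tensor


base = Meta(slot_mapping=torch.zeros(8, dtype=torch.int32),
            seq_lens=torch.ones(4, dtype=torch.int32))

per_step = []
for step in range(3):
    m = copy.copy(base)                       # cheap: no tensor data is copied
    m.slot_mapping = torch.full((8,), step,   # reassign, do not mutate in place
                                dtype=torch.int32)
    per_step.append(m)

assert int(base.slot_mapping.sum()) == 0                     # original buffer untouched
assert all(m.seq_lens is base.seq_lens for m in per_step)    # shared until reassigned
```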
+ common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step] + attn_metadata_eagle = builder.build_for_graph_capture( + common_attn_metadata, AscendAttentionState.ChunkedPrefill) + per_layer_attn_metadata = dict() + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata_eagle + multi_steps_attn_metadata.append(per_layer_attn_metadata) model_input_ids = self.input_ids[:num_tokens] model_positions = self._get_positions(num_tokens) model_previous_hidden_states = self.hidden_states[:num_tokens] - for i in range(self.num_speculative_tokens): - if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL: - aclgraph_runtime_mode = CUDAGraphMode.NONE - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_actual_tokens=0, - in_profile_run=is_profile, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True): - model_previous_hidden_states, model_positions = self.maybe_pad_and_reduce( - model_previous_hidden_states, model_positions) + batch_size = num_tokens // (self.num_speculative_tokens + 1) + with set_ascend_forward_context( + multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, + self.vllm_config, + num_tokens=num_tokens, + num_actual_tokens=0, + in_profile_run=is_profile, + batch_descriptor=batch_descriptor, + aclgraph_runtime_mode=aclgraph_runtime_mode, + is_draft_model=True): - self.model( - input_ids=model_input_ids, - positions=model_positions, - hidden_states=model_previous_hidden_states, - ) - forward_context = get_forward_context() - if (forward_context.cudagraph_runtime_mode - == CUDAGraphMode.FULL - and not forward_context.capturing): - self._update_full_graph_params(forward_context, num_tokens) - - model_previous_hidden_states, model_positions, _ = self.maybe_all_gather_and_unpad( - model_previous_hidden_states, model_positions) - - dummy_compute_logits(self.hidden_states) + self._runnable( + num_input_tokens=num_tokens, + batch_size=batch_size, + last_token_indices=self.last_token_indices[:batch_size], + # The target_position's address is same as the model_positions's + target_positions=model_positions, + inputs_embeds=None, + multi_steps_attn_metadata=multi_steps_attn_metadata, + ) + forward_context = get_forward_context() + if (forward_context.cudagraph_runtime_mode + == CUDAGraphMode.FULL + and not forward_context.capturing): + self._update_full_graph_params(forward_context, num_tokens, + multi_steps_attn_metadata) def _propose( self, @@ -408,17 +450,59 @@ class EagleProposer(VllmEagleProposer): inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] + # Update slot_mapping for different speculative. + # NOTE: Currently, we only remake the slot_mapping, because it's the + # only tensor which will be used in current FIA. + # Strictly speaking, `query_start_loc`, `seq_lens` should also have + # their memory allocated separately for each step just like `slot_mapping`. + slot_mapping_lens = num_input_tokens if num_input_tokens < \ + common_attn_metadata.slot_mapping.shape[0] else \ + common_attn_metadata.slot_mapping.shape[0] + self.slot_mapping_group[0][:slot_mapping_lens].copy_( + common_attn_metadata.slot_mapping[:slot_mapping_lens]) + self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1) + common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens] + # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
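The `slot_mapping_group` buffers follow the usual stable-address pattern for full-graph capture: each step's captured kernels read from a fixed tensor, and only the contents are refreshed before replay. A minimal sketch under assumed shapes (CPU tensors here, NPU tensors in the real code):

```python
import torch

num_steps, max_tokens = 3, 16
# One pre-allocated buffer per speculative step; the graph is captured reading
# from these fixed addresses.
slot_mapping_group = [torch.full((max_tokens,), -1, dtype=torch.int32)
                      for _ in range(num_steps)]


def fill_step(step: int, fresh_slots: torch.Tensor) -> None:
    """Copy this step's slot mapping into its pre-allocated buffer."""
    n = min(fresh_slots.shape[0], max_tokens)
    slot_mapping_group[step][:n].copy_(fresh_slots[:n].to(torch.int32))
    slot_mapping_group[step][n:].fill_(-1)    # pad the tail with an invalid slot id


buffer_before = slot_mapping_group[0]
fill_step(0, torch.arange(10))
assert slot_mapping_group[0] is buffer_before   # same storage, only the contents changed
```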
builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) + # update global cos, sin update_cos_sin(self._get_positions(num_input_tokens)) - per_layer_attn_metadata = {} + + if self.uses_mrope: + used_update_positions = target_positions[:, last_token_indices] + else: + used_update_positions = target_positions[last_token_indices] + per_layer_attn_metadata = dict() + # The first step of speculative. for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata + multi_steps_attn_metadata = [per_layer_attn_metadata] + + # Copy the old attn_metadata and update + for draft_step in range(1, self.num_speculative_tokens): + common_attn_metadata, attn_metadata = \ + self.attn_update_stack_num_spec_norm( + draft_step, + attn_metadata, + common_attn_metadata, + batch_size, + num_input_tokens, + used_update_positions, + aclgraph_runtime_mode) + per_layer_attn_metadata = dict() + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + multi_steps_attn_metadata.append(per_layer_attn_metadata) + + last_token_indices_len = last_token_indices.shape[0] + self.last_token_indices[:last_token_indices_len].copy_( + last_token_indices) + with set_ascend_forward_context( - per_layer_attn_metadata, + multi_steps_attn_metadata[0], self.vllm_config, num_tokens=num_input_tokens, num_actual_tokens=num_tokens, @@ -426,34 +510,52 @@ class EagleProposer(VllmEagleProposer): aclgraph_runtime_mode=aclgraph_runtime_mode, is_draft_model=True): - # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. - # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. - model_input_ids = self.input_ids[:num_input_tokens] - model_positions = self._get_positions(num_input_tokens) - model_hidden_states = self.hidden_states[:num_input_tokens] - - model_hidden_states, model_positions = self.maybe_pad_and_reduce( - model_hidden_states, model_positions) - - ret_hidden_states = self.model( - input_ids=model_input_ids, - positions=model_positions, - hidden_states=model_hidden_states, - inputs_embeds = inputs_embeds - ) - if self.method == "mtp": - last_hidden_states = ret_hidden_states - hidden_states = last_hidden_states - else: - last_hidden_states, hidden_states = ret_hidden_states + draft_token_ids = self._runnable( + num_input_tokens=num_input_tokens, + batch_size=batch_size, + last_token_indices=self.last_token_indices[:last_token_indices_len], + target_positions=target_positions, + inputs_embeds=inputs_embeds, + multi_steps_attn_metadata=multi_steps_attn_metadata) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: self._update_full_graph_params(forward_context, - num_input_tokens) + num_input_tokens, + multi_steps_attn_metadata) + return draft_token_ids - last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states) + def _run_merged_draft(self, + num_input_tokens, + batch_size, + last_token_indices, + target_positions, + inputs_embeds, + multi_steps_attn_metadata, + ) -> torch.Tensor: + # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. + # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. 
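A structural sketch of the merged-draft flow: all per-step attention metadata is prepared first, one callable runs every speculative step, and the attention params are patched afterwards. `GraphWrapper` and `draft_model` are toy stand-ins, not the `ACLGraphWrapper` API:

```python
from typing import Callable, List

import torch


class GraphWrapper:
    """Toy stand-in for a full-graph wrapper: capture once, replay afterwards."""

    def __init__(self, fn: Callable):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)       # capture/replay machinery elided


def draft_model(hidden: torch.Tensor) -> torch.Tensor:
    return hidden + 1                         # toy stand-in for one forward pass


def run_merged_draft(hidden: torch.Tensor,
                     multi_steps_attn_metadata: List[dict]) -> torch.Tensor:
    drafts = []
    for step_metadata in multi_steps_attn_metadata:   # every step inside one graph
        _ = step_metadata                     # would be installed into the forward context
        hidden = draft_model(hidden)
        drafts.append(hidden.argmax(dim=-1))
    return torch.stack(drafts, dim=1)


runnable = GraphWrapper(run_merged_draft)
metadata = [{"layer0": {"step": step}} for step in range(3)]   # built before the call
draft_token_ids = runnable(torch.randn(4, 8), metadata)
# ...after the call, the attention params of all steps would be updated at once.
print(draft_token_ids.shape)                  # torch.Size([4, 3])
```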
+ model_input_ids = self.input_ids[:num_input_tokens] + model_positions = self._get_positions(num_input_tokens) + model_hidden_states = self.hidden_states[:num_input_tokens] + + model_hidden_states, model_positions = self.maybe_pad_and_reduce( + model_hidden_states, model_positions) + + ret_hidden_states = self.model( + input_ids=model_input_ids, + positions=model_positions, + hidden_states=model_hidden_states, + inputs_embeds = inputs_embeds, + ) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + hidden_states = last_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + + last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( + last_hidden_states, model_positions, hidden_states) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -477,53 +579,17 @@ class EagleProposer(VllmEagleProposer): hidden_states = hidden_states[last_token_indices] last_token_indices = self.arange[:batch_size] - if self.use_cuda_graph and \ - batch_size <= self.runner.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size + input_batch_size = num_input_tokens - if self.use_cuda_graph: - aclgraph_runtime_mode, batch_descriptor = \ - self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora) - else: - aclgraph_runtime_mode = CUDAGraphMode.NONE - batch_descriptor = None + forward_context = get_forward_context() + forward_context.num_tokens = input_batch_size + forward_context.num_accept_tokens = batch_size - if ( - aclgraph_runtime_mode == CUDAGraphMode.FULL - and (pad_size := input_batch_size - batch_size) > 0 - ): - common_attn_metadata.num_reqs = input_batch_size - common_attn_metadata.block_table_tensor = self._pad_tensor( - common_attn_metadata.block_table_tensor, pad_size) - common_attn_metadata.seq_lens = self._pad_tensor( - common_attn_metadata.seq_lens, pad_size) - common_attn_metadata.seq_lens_cpu = self._pad_tensor( - common_attn_metadata.seq_lens_cpu, pad_size) - common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor( - common_attn_metadata.num_computed_tokens_cpu, pad_size) - common_attn_metadata.query_start_loc = self.arange[ - :input_batch_size + 1] - common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:input_batch_size + 1]).clone() - else: - common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] - common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:batch_size + 1]).clone() - - common_attn_metadata.num_actual_tokens = batch_size - common_attn_metadata.max_query_len = 1 - common_attn_metadata.decode_token_per_req = 1 - common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill - common_attn_metadata.graph_pad_size = -1 - common_attn_metadata.num_input_tokens = input_batch_size - - for now_speculative in range(self.num_speculative_tokens - 1): + for draft_step in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. 
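A quick illustration of the dtype note above: `torch.argmax` returns int64, so the draft token ids need an explicit cast before being fed back as int32 inputs:

```python
import torch

logits = torch.randn(4, 32000)
draft_token_ids = logits.argmax(dim=-1)
print(draft_token_ids.dtype)                   # torch.int64
input_ids = draft_token_ids.to(torch.int32)    # cast expected by the compiled draft model
print(input_ids.dtype)                         # torch.int32
```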
- input_ids = draft_token_ids_tensor[now_speculative] + input_ids = draft_token_ids_tensor[draft_step] positions += 1 # NOTE(woosuk): We should handle the case where the draft model @@ -545,67 +611,6 @@ class EagleProposer(VllmEagleProposer): clamped_positions = torch.where(exceeds_max_model_len, 0, positions) - # For data integrity when async scheduling, we shouldn't use in place - # operations in case they are modified in next step's `prepare_input` - # of main model. - # Increment the sequence lengths. - common_attn_metadata.seq_lens[:batch_size] += 1 - # For the requests that exceed the max model length, we set the - # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens[:batch_size].masked_fill_( - exceeds_max_model_len, 1) - - common_attn_metadata.seq_lens_cpu[:batch_size] = ( - common_attn_metadata.seq_lens_cpu[:batch_size] + 1) - exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \ - self.max_model_len - common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_( - exceeds_mask, 1) - common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 - if self.uses_mrope: - common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) - else: - common_attn_metadata.positions[:batch_size].copy_(clamped_positions) - if self.attn_metadata_builder is None: - attn_metadata_builder = self._get_attention_metadata_builder() - else: - attn_metadata_builder = self.attn_metadata_builder - block_size = attn_metadata_builder.kv_cache_spec.block_size - - # Compute the slot mapping. - if self.uses_mrope: - block_numbers = clamped_positions[0] // block_size - else: - block_numbers = (clamped_positions // block_size) - block_ids = attn_metadata.block_tables.gather( - dim=1, index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - if self.uses_mrope: - slot_mapping = (block_ids * block_size + - clamped_positions[0] % block_size) - else: - slot_mapping = (block_ids * block_size + - clamped_positions % block_size) - - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) - - common_attn_metadata.slot_mapping[:slot_mapping.shape[0]].copy_( - slot_mapping.to(torch.int32)) - common_attn_metadata.slot_mapping[slot_mapping.shape[0]:].fill_( - PADDING_SLOT_ID) - - # Rebuild attention metadata - attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, - draft_index=now_speculative + 1, - ) - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata - # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) @@ -624,55 +629,175 @@ class EagleProposer(VllmEagleProposer): update_cos_sin(self._get_positions(input_batch_size)) # Run the model. - with set_ascend_forward_context( - per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size, - num_actual_tokens=batch_size, - batch_descriptor=batch_descriptor, - aclgraph_runtime_mode=aclgraph_runtime_mode, - is_draft_model=True): - # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. - # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. 
- model_input_ids = self.input_ids[:input_batch_size] - model_positions = self._get_positions(input_batch_size) - model_hidden_states = self.hidden_states[:input_batch_size] + # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings. + # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model. + model_input_ids = self.input_ids[:input_batch_size] + model_positions = self._get_positions(input_batch_size) + model_hidden_states = self.hidden_states[:input_batch_size] - model_hidden_states, model_positions = self.maybe_pad_and_reduce( - model_hidden_states, model_positions) + model_hidden_states, model_positions = self.maybe_pad_and_reduce( + model_hidden_states, model_positions) - ret_hidden_states = self.model( - input_ids=model_input_ids, - positions=model_positions, - hidden_states=model_hidden_states, - inputs_embeds = inputs_embeds - ) - if self.method == "mtp": - last_hidden_states = ret_hidden_states - hidden_states = last_hidden_states - else: - last_hidden_states, hidden_states = ret_hidden_states + forward_context.attn_metadata = multi_steps_attn_metadata[draft_step + 1] \ + if multi_steps_attn_metadata else None + ret_hidden_states = self.model( + input_ids=model_input_ids, + positions=model_positions, + hidden_states=model_hidden_states, + inputs_embeds = inputs_embeds, + ) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + hidden_states = last_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states - forward_context = get_forward_context() - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - self._update_full_graph_params(forward_context, - input_batch_size) - - last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( - last_hidden_states, model_positions, hidden_states) + last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad( + last_hidden_states, model_positions, hidden_states) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size]) # TODO(wenlong): get more than one token for tree attention draft_token_ids = logits.argmax(dim=-1) - draft_token_ids_tensor[now_speculative + 1] = draft_token_ids + draft_token_ids_tensor[draft_step + 1] = draft_token_ids # [batch_size, num_speculative_tokens] draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) return draft_token_ids + def attn_update_stack_num_spec_norm(self, + # `draft_step` must start from `1`, no `0` + draft_step, + old_attn_metadata, + old_common_metadata, + batch_size, + input_batch_size, + used_update_positions, + aclgraph_runtime_mode): + + assert(draft_step > 0) + common_attn_metadata = self.shallow_copy_metadata(old_common_metadata) + + if draft_step == 1: + if ( + aclgraph_runtime_mode == CUDAGraphMode.FULL + and (pad_size := input_batch_size - batch_size) > 0 + ): + common_attn_metadata.num_reqs = input_batch_size + common_attn_metadata.block_table_tensor = self._pad_tensor( + common_attn_metadata.block_table_tensor, pad_size) + common_attn_metadata.seq_lens = self._pad_tensor( + common_attn_metadata.seq_lens, pad_size) + common_attn_metadata.seq_lens_cpu = self._pad_tensor( + common_attn_metadata.seq_lens_cpu, pad_size) + common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor( + common_attn_metadata.num_computed_tokens_cpu, pad_size) + common_attn_metadata.query_start_loc = self.arange[ + :input_batch_size + 1] + 
common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:input_batch_size + 1]).clone() + else: + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:batch_size + 1]).clone() + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.decode_token_per_req = 1 + common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + common_attn_metadata.graph_pad_size = -1 + common_attn_metadata.num_input_tokens = input_batch_size + + # The loop part + + used_update_positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + if self.uses_mrope: + exceeds_max_model_len = used_update_positions[ + 0] >= self.vllm_config.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(used_update_positions), used_update_positions) + else: + exceeds_max_model_len = used_update_positions >= \ + self.vllm_config.model_config.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, + used_update_positions) + + # For data integrity when async scheduling, we shouldn't use in place + # operations in case they are modified in next step's `prepare_input` + # of main model. + # Increment the sequence lengths. + common_attn_metadata.seq_lens[:batch_size] += 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + common_attn_metadata.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len, 1) + + common_attn_metadata.seq_lens_cpu[:batch_size] = ( + common_attn_metadata.seq_lens_cpu[:batch_size] + 1) + exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \ + self.max_model_len + common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_( + exceeds_mask, 1) + common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 + if self.uses_mrope: + common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) + else: + common_attn_metadata.positions[:batch_size].copy_(clamped_positions) + + if self.attn_metadata_builder is None: + attn_metadata_builder = self._get_attention_metadata_builder() + else: + attn_metadata_builder = self.attn_metadata_builder + block_size = attn_metadata_builder.kv_cache_spec.block_size + + # Compute the slot mapping. + if self.uses_mrope: + block_numbers = clamped_positions[0] // block_size + else: + block_numbers = (clamped_positions // block_size) + block_ids = old_attn_metadata.block_tables.gather( + dim=1, index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + if self.uses_mrope: + slot_mapping = (block_ids * block_size + + clamped_positions[0] % block_size) + else: + slot_mapping = (block_ids * block_size + + clamped_positions % block_size) + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. 
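A worked example of the slot-mapping arithmetic used when advancing positions by one draft step (toy sizes; the `PADDING_SLOT_ID` value is assumed): `slot = block_id * block_size + position % block_size`, with requests past the max model length masked out:

```python
import torch

PADDING_SLOT_ID = -1                  # assumed sentinel value
block_size, max_model_len = 4, 12
positions = torch.tensor([5, 11, 12])                   # the last request is at the limit
block_tables = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

exceeds = positions >= max_model_len
clamped = torch.where(exceeds, torch.zeros_like(positions), positions)

block_numbers = clamped // block_size
block_ids = block_tables.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped % block_size
slot_mapping.masked_fill_(exceeds, PADDING_SLOT_ID)     # never write KV for padded tokens
print(slot_mapping)                                     # tensor([ 5, 23, -1])
```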
+ slot_mapping.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + self.slot_mapping_group[draft_step][:slot_mapping.shape[0]].copy_( + slot_mapping.to(torch.int32)) + self.slot_mapping_group[draft_step][slot_mapping.shape[0]:].fill_( + PADDING_SLOT_ID) + # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] + common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][ + :slot_mapping.shape[0]] + + # Rebuild attention metadata + attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore + common_attn_metadata=common_attn_metadata, + draft_index=draft_step, + ) + + return common_attn_metadata, attn_metadata + def prepare_next_token_ids_padded( self, common_attn_metadata: CommonAttentionMetadata, @@ -1011,7 +1136,7 @@ class EagleProposer(VllmEagleProposer): return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens # update full-graph params for one spec token - def _update_full_graph_params(self, forward_context, num_tokens): + def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None): if self.vllm_config.model_config.use_mla: if self.pcp_size * self.dcp_size > 1: update_mla_attn_dcp_pcp_params(self.update_stream, @@ -1026,7 +1151,7 @@ class EagleProposer(VllmEagleProposer): num_tokens) else: update_attn_params(self.update_stream, forward_context, - num_tokens, self.vllm_config) + num_tokens, self.vllm_config, draft_attn_metadatas) # padding tensor into desired size def _pad_tensor(self, tensor, pad_size):
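To close the loop on the "remove synchronization" idea, a minimal sketch of the replay-time rule added in `acl_graph.py`, with hypothetical names for the context object: the pre-replay synchronize is skipped only for the draft model running the merged eagle graph in FULL mode:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class CUDAGraphMode(Enum):
    NONE = auto()
    FULL = auto()


@dataclass
class Ctx:
    # Hypothetical stand-in for the forward context.
    is_draft_model: bool


def needs_sync(runtime_mode: CUDAGraphMode, ctx: Ctx, spec_method: Optional[str]) -> bool:
    use_eagle = spec_method in ("eagle", "eagle3")
    # Synchronize unless this is the draft model, in FULL mode, with an eagle method.
    return (runtime_mode != CUDAGraphMode.FULL
            or not ctx.is_draft_model
            or not use_eagle)


assert needs_sync(CUDAGraphMode.FULL, Ctx(is_draft_model=False), "eagle")       # main model still syncs
assert not needs_sync(CUDAGraphMode.FULL, Ctx(is_draft_model=True), "eagle3")   # merged draft graph skips it
assert needs_sync(CUDAGraphMode.FULL, Ctx(is_draft_model=True), "mtp")          # mtp keeps the old behavior
```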