[Feat] Merge the multi eagle graphs to one graph (#5940)

### What this PR does / why we need it?
This PR merges all steps of the draft model into a single graph in fullgraph
mode, avoiding the synchronization between the per-step graphs and reducing
bubble time.

#### Key ideas:
- The "model forward" of the step 0 (first step) and remaining steps are
captured together as a "Callable", rather than capturing each model
individually.
- "update_attn_params" is moved outside the entire graph, meaning that
all "attn_metadata" required by all steps are constructed before
"replay", and the "attn_params" of all steps are updated at once.
- Remove synchronization between the main model graph and draft model
graph.
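
A minimal sketch of the merged-callable idea (the class name `MergedDraftRunner` and the loop body below are illustrative assumptions, not the actual `_run_merged_draft` implementation):

```python
# Illustrative sketch only: all draft steps live inside one callable, so a
# single ACL graph capture/replay covers the whole speculative chain and no
# torch.npu.synchronize() is needed between steps.
import torch


class MergedDraftRunner:
    def __init__(self, draft_model, lm_head, num_speculative_steps: int):
        self.draft_model = draft_model
        self.lm_head = lm_head
        self.num_steps = num_speculative_steps

    def __call__(self, input_ids, positions, hidden_states):
        draft_tokens = []
        for _ in range(self.num_steps):
            # Step 0 and the remaining steps share the same captured graph.
            hidden_states = self.draft_model(
                input_ids=input_ids,
                positions=positions,
                previous_hidden_states=hidden_states,
            )
            logits = self.lm_head(hidden_states)
            input_ids = torch.argmax(logits, dim=-1)
            positions = positions + 1
            draft_tokens.append(input_ids)
        return torch.stack(draft_tokens, dim=1)
```

The whole loop is what gets captured as one graph; because attention parameters are never touched inside it, they can all be refreshed once before replay.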

#### Key params/functions:
- params: draft_attn_metadatas, attn_metadata_multi_steps,
slot_mapping_group
- functions: _run_merged_draft, attn_update_stack_num_spec_norm,
update_attn_params, _propose, dummy_run (an illustrative usage sketch follows
this list)
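
A rough usage sketch of how these pieces line up before replay. Only `update_attn_params` and its new `draft_attn_metadatas` argument come from this PR; the driver function, `build_attn_metadata_for_step`, and `merged_draft_graph` are assumed names for illustration:

```python
# Hypothetical driver: per-step metadata is built up front, the attn params
# of every step are patched in one call, and the merged graph is then
# replayed without intermediate synchronization.
# (update_attn_params is the function modified in this PR; import path omitted.)
def replay_merged_draft(update_stream, forward_context, runtime_shape,
                        vllm_config, num_speculative_steps,
                        build_attn_metadata_for_step, merged_draft_graph):
    # Construct the attn_metadata for every draft step before any replay.
    draft_attn_metadatas = [
        build_attn_metadata_for_step(step)  # hypothetical per-step builder
        for step in range(num_speculative_steps)
    ]
    # One call updates the attention params of all draft steps ...
    update_attn_params(
        update_stream,
        forward_context,
        runtime_shape,
        vllm_config,
        draft_attn_metadatas=draft_attn_metadatas,
    )
    # ... so the merged graph can replay all steps back to back.
    merged_draft_graph.replay()
```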

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
11b6af5280

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
commit 7725314b26 (parent 63d3921208)
Author: anon189Ty
Date: 2026-01-23 08:37:02 +08:00
Committed by: GitHub
5 changed files with 396 additions and 218 deletions


@@ -191,7 +191,15 @@ class ACLGraphWrapper:
# before the graph replay of iteration i-1.
# To ensure proper ordering, we must call synchronize here before replaying,
# so that update_attn_params only executes after the previous graph replay has fully completed.
torch.npu.synchronize()
# When the draft model runs in full-graph mode with the merged eagle graph,
# this synchronization is not needed.
use_eagle = (
self.vllm_config.speculative_config.method in ("eagle", "eagle3")
if self.vllm_config.speculative_config
else False
)
if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle:
torch.npu.synchronize()
entry.aclgraph.replay()
return entry.output
@@ -247,18 +255,31 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
event.record(update_stream)
def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
def _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas=None):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
attn_metadata = draft_attn_metadatas
attn_keys = list(attn_metadata[0].keys())
else:
graph_params = get_graph_params()
attn_metadata = forward_context.attn_metadata
attn_keys = list(attn_metadata.keys())
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly
# filters out the update operations for linear_attn.
# TODO: We use a new variable `attn_keys` so the loop count from `zip` stays
# correct with the new attn_metadata structure of the merged full eagle graph.
# This still needs to be verified against Qwen3-next.
num_layers = len(attn_keys)
if num_layers == 0:
return
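# The draft graph holds (num_steps * num_layers) attn_params entries, so the
# per-layer keys are repeated once per draft step to keep the zip aligned.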
if forward_context.is_draft_model:
attn_keys = attn_keys * (len(graph_params.attn_params[runtime_shape]) // num_layers)
attn_count = 0
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
attn_keys,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
@@ -279,8 +300,15 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
softmax_lse,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens_list
actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
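# attn_count enumerates (step, layer) pairs in capture order; integer
# division by num_layers recovers the current draft step.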
if forward_context.is_draft_model:
draft_step = attn_count // num_layers
seq_lens = attn_metadata[draft_step][key].seq_lens_list
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
attn_count = attn_count + 1
else:
seq_lens = attn_metadata[key].seq_lens_list
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
@@ -304,11 +332,11 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
event.record(update_stream)
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config, draft_attn_metadatas=None):
if using_paged_attention(runtime_shape, vllm_config):
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
else:
_update_attn_fia_params(update_stream, forward_context, runtime_shape)
_update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas)
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):