[Feat] Merge the multi eagle graphs to one graph (#5940)
### What this PR does / why we need it?
This PR merges all steps of the draft model into a single graph in fullgraph mode, avoiding
synchronization between the individual graphs and reducing bubble time.
#### Key ideas:
- The "model forward" of the step 0 (first step) and remaining steps are
captured together as a "Callable", rather than capturing each model
individually.
- "update_attn_params" is moved outside the entire graph, meaning that
all "attn_metadata" required by all steps are constructed before
"replay", and the "attn_params" of all steps are updated at once.
- Remove synchronization between the main model graph and draft model
graph.
#### Key params/functions:
- params: draft_attn_metadatas, attn_metadata_multi_steps,
slot_mapping_group
- functions: _run_merged_draft, attn_update_stack_num_spec_norm,
update_attn_params, _propose, dummy_run
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -191,7 +191,15 @@ class ACLGraphWrapper:
|
||||
# before the graph replay of iteration i-1.
|
||||
# To ensure proper ordering, we must call synchronize here before replaying,
|
||||
# so that update_attn_params only executes after the previous graph replay has fully completed.
|
||||
torch.npu.synchronize()
|
||||
# If we are not in the main model and we are in full-graph mode using merge-eagle-graph,
|
||||
# we do not need to synchronize.
|
||||
use_eagle = (
|
||||
self.vllm_config.speculative_config.method in ("eagle", "eagle3")
|
||||
if self.vllm_config.speculative_config
|
||||
else False
|
||||
)
|
||||
if self.runtime_mode != CUDAGraphMode.FULL or not forward_context.is_draft_model or not use_eagle:
|
||||
torch.npu.synchronize()
|
||||
entry.aclgraph.replay()
|
||||
return entry.output
|
||||
|
||||
@@ -247,18 +255,31 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
def _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas=None):
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
attn_metadata = draft_attn_metadatas
|
||||
attn_keys = list(attn_metadata[0].keys())
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
attn_keys = list(attn_metadata.keys())
|
||||
# For Qwen3-next, since the kv_cache_config has already categorized
|
||||
# linear_attn and self_attn, the attn_metadata is first arranged with
|
||||
# self_attn followed by linear_attn. Therefore, using zip directly
|
||||
# filters out the update operations for linear_attn.
|
||||
# TODO: We use a new variable `attn_keys` to ensure the loop count is
|
||||
# correct after get by `zip` because of the new structure of the attn_metadata
|
||||
# when running with the merged full eagle-graph. Should check it with Qwen3-next.
|
||||
num_layers = len(attn_keys)
|
||||
if num_layers == 0:
|
||||
return
|
||||
if forward_context.is_draft_model:
|
||||
attn_keys = attn_keys * (len(graph_params.attn_params[runtime_shape]) // num_layers)
|
||||
attn_count = 0
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
attn_keys,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
@@ -279,8 +300,15 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
softmax_lse,
|
||||
) = param
|
||||
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens_list
|
||||
actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
|
||||
if forward_context.is_draft_model:
|
||||
draft_step = attn_count // num_layers
|
||||
seq_lens = attn_metadata[draft_step][key].seq_lens_list
|
||||
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
|
||||
attn_count = attn_count + 1
|
||||
else:
|
||||
seq_lens = attn_metadata[key].seq_lens_list
|
||||
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
|
||||
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
@@ -304,11 +332,11 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config, draft_attn_metadatas=None):
    """Update captured attention parameters for a graph of the given runtime shape.

    Dispatches to the paged-attention updater or the fused-infer-attention
    updater depending on the attention backend selected for this shape.

    Args:
        update_stream: NPU stream on which the parameter-update kernels run.
        forward_context: Current forward context (carries attn_metadata and
            the is_draft_model flag consumed by the updaters).
        runtime_shape: Key identifying the captured graph whose params are updated.
        vllm_config: Config used to decide between paged attention and FIA.
        draft_attn_metadatas: Per-step attention metadata for the merged
            eagle draft graph; ``None`` outside the merged-draft path.
            NOTE(review): it is only consumed by the FIA path — presumably the
            paged-attention path is never taken for the merged draft graph;
            confirm with callers.
    """
    if using_paged_attention(runtime_shape, vllm_config):
        _update_attn_pa_params(update_stream, forward_context, runtime_shape)
    else:
        # FIA path also handles the merged multi-step draft metadata.
        _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
|
||||
|
||||
Reference in New Issue
Block a user