[Refactor][EAGLE] 4/N extract common methods from eagle and mtp (#5870)

### What this PR does / why we need it?
This PR aims to extract common methods from eagle_proposer and
mtp_proposer. This is a small step towards merging eagle and mtp.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Verified by CI.

- vLLM version: v0.13.0
- vLLM main: bde38c11df

---------

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2026-01-15 10:24:35 +08:00
committed by GitHub
parent c11a05c4e1
commit ea01aeaab7
4 changed files with 109 additions and 123 deletions

View File

@@ -91,23 +91,7 @@ class EagleProposer(VllmEagleProposer):
super().__init__(vllm_config, device, runner)
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
# There is synchronization between MTP steps when aclgraph is enabled, so
# aclgraph is disabled when async scheduling is used, to avoid the
# synchronization overhead.
# NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run
# and _propose.
self.use_cuda_graph = (
self.vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not self.vllm_config.model_config.enforce_eager
and not self.use_async_scheduling
and not self.vllm_config.speculative_config.enforce_eager)
self.cudagraph_batch_sizes = list(
sorted(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
self.pcp_size = self.runner.pcp_size
self.decode_threshold = 1 + self.num_speculative_tokens
self.query_start_loc = self.runner._make_buffer(
self.runner.max_num_reqs + 1, dtype=torch.int32)
@@ -118,12 +102,11 @@ class EagleProposer(VllmEagleProposer):
self.enable_shared_expert_dp = shared_expert_dp_enabled()
self.pcp_size = self.runner.pcp_size
self.dcp_size = self.runner.dcp_size
self.pcp_rank = self.runner.pcp_rank
self.dcp_rank = self.runner.dcp_rank
self.use_aclgraph = self.runner._use_aclgraph()
self.full_indices = range(
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
@@ -131,6 +114,10 @@ class EagleProposer(VllmEagleProposer):
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
"index_topk")
self.use_cuda_graph = (self.runner._use_aclgraph()
and not self.speculative_config.enforce_eager
and not self.use_async_scheduling)
# TODO: Remove it when the bug of fx-graph is solved
self.maybe_eager_context: ContextManager[Any] = nullcontext()
if not self.use_cuda_graph and enable_sp(vllm_config):
@@ -158,8 +145,7 @@ class EagleProposer(VllmEagleProposer):
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
self.attn_layer_names = self.attn_layer_name
self.attn_layer_names = list(draft_attn_layer_names)
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
@@ -273,7 +259,7 @@ class EagleProposer(VllmEagleProposer):
attn_metadata_eagle = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
attn_metadata = {}
for layer_name in self.attn_layer_name:
for layer_name in self.attn_layer_names:
attn_metadata[layer_name] = attn_metadata_eagle
model_input_ids = self.input_ids[:num_tokens]
@@ -292,30 +278,22 @@ class EagleProposer(VllmEagleProposer):
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
forward_context = get_forward_context()
if forward_context.sp_enabled:
model_previous_hidden_states = split_inputs_tp_to_sp(
model_previous_hidden_states,
model_previous_hidden_states)
model_previous_hidden_states, model_positions = self.maybe_pad_and_reduce(
model_previous_hidden_states, model_positions)
self.model(
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_previous_hidden_states,
)
forward_context = get_forward_context()
if (forward_context.cudagraph_runtime_mode
== CUDAGraphMode.FULL
and not forward_context.capturing):
update_attn_params(
self.update_stream,
forward_context,
num_tokens,
self.vllm_config,
)
self._update_full_graph_params(forward_context, num_tokens)
if forward_context.sp_enabled:
model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
model_previous_hidden_states, True)
model_previous_hidden_states, model_positions, _ = self.maybe_all_gather_and_unpad(
model_previous_hidden_states, model_positions)
dummy_compute_logits(self.hidden_states)
@@ -362,7 +340,7 @@ class EagleProposer(VllmEagleProposer):
self.input_ids[last_token_indices] = next_token_ids
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
@@ -386,7 +364,7 @@ class EagleProposer(VllmEagleProposer):
# update global cos, sin
update_cos_sin(self.positions[:num_input_tokens])
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_name:
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
with set_ascend_forward_context(
per_layer_attn_metadata,
@@ -403,34 +381,27 @@ class EagleProposer(VllmEagleProposer):
model_positions = self.positions[:num_input_tokens]
model_hidden_states = self.hidden_states[:num_input_tokens]
forward_context = get_forward_context()
if forward_context.sp_enabled:
# split hidden states along sequence dimension
# positions should not be split?
model_hidden_states = split_inputs_tp_to_sp(
model_hidden_states, model_hidden_states)
model_hidden_states, model_positions = self.maybe_pad_and_reduce(
model_hidden_states, model_positions)
last_hidden_states, hidden_states = self.model(
ret_hidden_states = self.model(
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_hidden_states,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
# TODO: support mla in future.
update_attn_params(
self.update_stream,
forward_context,
num_input_tokens,
self.vllm_config,
)
self._update_full_graph_params(forward_context,
num_input_tokens)
if forward_context.sp_enabled:
# merge hidden states along sequence dimension
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), True)
last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
last_hidden_states, model_positions, hidden_states)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -453,7 +424,7 @@ class EagleProposer(VllmEagleProposer):
last_token_indices = self.arange[:batch_size]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
batch_size <= self.runner.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
@@ -556,32 +527,27 @@ class EagleProposer(VllmEagleProposer):
model_positions = self.positions[:input_batch_size]
model_hidden_states = self.hidden_states[:input_batch_size]
forward_context = get_forward_context()
if forward_context.sp_enabled:
# split hidden states along sequence dimension
# positions should not be split
model_hidden_states = split_inputs_tp_to_sp(
model_hidden_states, model_hidden_states)
model_hidden_states, model_positions = self.maybe_pad_and_reduce(
model_hidden_states, model_positions)
last_hidden_states, hidden_states = self.model(
ret_hidden_states = self.model(
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_hidden_states,
)
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
update_attn_params(
self.update_stream,
forward_context,
input_batch_size,
self.vllm_config,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
if forward_context.sp_enabled:
# merge hidden states along sequence dimension
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), True)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
self._update_full_graph_params(forward_context,
input_batch_size)
last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
last_hidden_states, model_positions, hidden_states)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
@@ -948,3 +914,46 @@ class EagleProposer(VllmEagleProposer):
else:
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)
def maybe_pad_and_reduce(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare draft-model inputs right before the model forward call.

    Two mutually exclusive paths are selected by ``self.method``:

    * ``"mtp"``: when ``self.enable_shared_expert_dp`` is set, both
      ``hidden_states`` and ``positions`` are passed through
      ``torch.ops.vllm.maybe_pad_and_reduce``; otherwise nothing is done.
    * any other method: when the current forward context reports
      ``sp_enabled``, ``hidden_states`` is passed through
      ``split_inputs_tp_to_sp``; ``positions`` is never modified here.

    Args:
        hidden_states: Hidden states that will be fed to the draft model.
        positions: Position tensor that will be fed to the draft model.

    Returns:
        The ``(hidden_states, positions)`` pair, transformed or untouched
        depending on the path taken.
    """
    # NOTE(review): nesting was reconstructed from a whitespace-stripped
    # diff; the ``else`` is paired with the ``method`` check because the
    # non-mtp call sites this helper replaced used the sp_enabled split.
    if self.method == "mtp":
        if self.enable_shared_expert_dp:
            hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
                hidden_states)
            # The op is applied to positions with a trailing singleton
            # axis, which is removed again afterwards.
            positions = positions.unsqueeze(-1)
            positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
            positions = positions.squeeze(-1)
    else:
        forward_context = get_forward_context()
        if forward_context.sp_enabled:
            # The same tensor is passed for both arguments, exactly as the
            # inlined code at the former call sites did.
            hidden_states = split_inputs_tp_to_sp(
                hidden_states, hidden_states)
    return hidden_states, positions
def maybe_all_gather_and_unpad(
self,
last_hidden_states: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if self.method == "mtp":
if self.enable_shared_expert_dp:
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True)
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
positions.contiguous(), True)
if hidden_states is not None:
hidden_states = last_hidden_states
else:
forward_context = get_forward_context()
if forward_context.sp_enabled:
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True)
if hidden_states is not None:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), True)
return last_hidden_states, positions, hidden_states