[Feat][Bugfix][main] Adapted SP to eagle3 (#5562)

### What this PR does / why we need it?

Adapted SP to Eagle3.

There may still be some open issues, e.g., accuracy in some scenarios and the
`sp`+`dp` combination.

We will fix them later.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

We tested it mainly with a new `e2e` test:

```shell
pytest -s tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_llama_qwen_eagle_acceptance
```

```text
.

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============= 3 passed, 1 skipped, 2 warnings in 142.05s (0:02:22) =============
```

It passed.

- vLLM version: v0.13.0
- vLLM main:
7157596103

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2026-01-08 15:33:52 +08:00
committed by GitHub
parent d03cc9c456
commit ccbc5e2ba1
4 changed files with 217 additions and 15 deletions

View File

@@ -225,7 +225,6 @@ class EagleProposer(VllmEagleProposer):
decode_token_per_req=self.runner.decode_token_per_req,
max_seq_len=0,
)
dummy_compute_logits(self.hidden_states)
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_eagle = builder.build_for_graph_capture(
@@ -233,6 +232,10 @@ class EagleProposer(VllmEagleProposer):
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_eagle
model_input_ids = self.input_ids[:num_tokens]
model_positions = self.positions[:num_tokens]
model_previous_hidden_states = self.hidden_states[:num_tokens]
for i in range(self.num_speculative_tokens):
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
aclgraph_runtime_mode = CUDAGraphMode.NONE
@@ -245,12 +248,17 @@ class EagleProposer(VllmEagleProposer):
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
forward_context = get_forward_context()
if self.enable_shared_expert_dp:
model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
model_previous_hidden_states)
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_previous_hidden_states,
)
forward_context = get_forward_context()
if (forward_context.cudagraph_runtime_mode
== CUDAGraphMode.FULL
and not forward_context.capturing):
@@ -261,6 +269,12 @@ class EagleProposer(VllmEagleProposer):
self.vllm_config,
)
if self.enable_shared_expert_dp:
model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
model_previous_hidden_states, True)
dummy_compute_logits(self.hidden_states)
def _propose(
self,
# [num_tokens]
@@ -338,10 +352,24 @@ class EagleProposer(VllmEagleProposer):
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
model_input_ids = self.input_ids[:num_input_tokens]
model_positions = self.positions[:num_input_tokens]
model_hidden_states = self.hidden_states[:num_input_tokens]
if self.enable_shared_expert_dp:
# split hidden states along sequence dimension
# positions should not be split?
model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
model_hidden_states)
# in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_hidden_states,
)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
@@ -352,6 +380,14 @@ class EagleProposer(VllmEagleProposer):
num_input_tokens,
self.vllm_config,
)
if self.enable_shared_expert_dp:
# merge hidden states along sequence dimension
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), True)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1)
@@ -470,10 +506,23 @@ class EagleProposer(VllmEagleProposer):
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
model_input_ids = self.input_ids[:input_batch_size]
model_positions = self.positions[:input_batch_size]
model_hidden_states = self.hidden_states[:input_batch_size]
if self.enable_shared_expert_dp:
# split hidden states along sequence dimension
# positions should not be split
model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
model_hidden_states)
# in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
input_ids=model_input_ids,
positions=model_positions,
hidden_states=model_hidden_states,
)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
@@ -483,6 +532,14 @@ class EagleProposer(VllmEagleProposer):
input_batch_size,
self.vllm_config,
)
if self.enable_shared_expert_dp:
# merge hidden states along sequence dimension
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
last_hidden_states.contiguous(), True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), True)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])