[Feat][Bugfix][main] Adapted SP to eagle3 (#5562)

### What this PR does / why we need it?

Adapted SP (sequence parallelism) to Eagle3 speculative decoding.

There may still be some problems, e.g., accuracy in some scenarios and the
`sp` + `dp` combination.

We will fix them later.
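For context, the core of the adaptation is wrapping the drafter's hidden states in the usual SP round-trip. The sketch below is a minimal illustration, not the vLLM implementation: it assumes `torch.ops.vllm.maybe_pad_and_reduce` and `torch.ops.vllm.maybe_all_gather_and_maybe_unpad` behave roughly like the two helpers here (the helper names are ours, and an initialized `torch.distributed` process group is assumed):

```python
# Minimal sketch (an assumption, not the vLLM code) of the SP round-trip
# applied to the drafter's hidden states in this PR.
import torch
import torch.distributed as dist


def pad_and_reduce_scatter(hidden_states: torch.Tensor) -> torch.Tensor:
    """Pad the token dim to a multiple of world_size, then reduce-scatter
    so each rank keeps a 1/world_size slice of the (summed) tokens."""
    world_size = dist.get_world_size()
    num_tokens, hidden_size = hidden_states.shape
    pad_size = (-num_tokens) % world_size
    if pad_size > 0:
        hidden_states = torch.nn.functional.pad(hidden_states,
                                                (0, 0, 0, pad_size))
    local = torch.empty((num_tokens + pad_size) // world_size, hidden_size,
                        dtype=hidden_states.dtype,
                        device=hidden_states.device)
    dist.reduce_scatter_tensor(local, hidden_states.contiguous())
    return local


def all_gather_and_unpad(local: torch.Tensor, pad_size: int) -> torch.Tensor:
    """Inverse step: gather every rank's slice and drop the padding rows."""
    world_size = dist.get_world_size()
    full = torch.empty(local.shape[0] * world_size, local.shape[1],
                       dtype=local.dtype, device=local.device)
    dist.all_gather_into_tensor(full, local.contiguous())
    return full[:-pad_size] if pad_size > 0 else full
```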

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

We tested it mainly with a new `e2e` test.

```shell
pytest -s tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_llama_qwen_eagle_acceptance
```

```text
.

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============= 3 passed, 1 skipped, 2 warnings in 142.05s (0:02:22) =============
```

It passed.
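The new SP acceptance test added below can also be selected on its own:

```shell
pytest -s tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_eagle3_sp_acceptance
```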

- vLLM version: v0.13.0
- vLLM main: 7157596103

Signed-off-by: drslark <slarksblood@qq.com>
drslark committed on 2026-01-08 15:33:52 +08:00 (committed via GitHub); commit ccbc5e2ba1, parent d03cc9c456.
4 changed files with 217 additions and 15 deletions.

View File

@@ -34,6 +34,10 @@ BASELINES = {
    "eagle3": [0.68, 0.40, 0.18],
}

BASELINES_SP = {
    "eagle3": [0.68, 0.40, 0.18],
}


@pytest.fixture
def test_prompts():
@@ -371,3 +375,111 @@ def test_llama_qwen_eagle_acceptance(
        print(f"golden: {golden}")

    assert match


# TODO: SP support in eagle3 is still being improved gradually; there are
# known problems when sp + dp is enabled, and in some not-yet-understood
# scenarios. This e2e test should also be improved gradually.
@pytest.mark.parametrize("method", ["eagle3"])
@pytest.mark.parametrize("num_speculative_tokens", [3])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_eagle3_sp_acceptance(
    method: str,
    num_speculative_tokens: int,
    disable_padded_drafter_batch: bool,
    async_scheduling: bool,
):
    if disable_padded_drafter_batch and async_scheduling:
        pytest.skip(
            "skip disable_padded_drafter_batch=True and async_scheduling=True",
        )

    main_model_name = MODELS[method]["main"]
    spec_model_name = MODELS[method]["spec"]

    tokenizer = AutoTokenizer.from_pretrained(
        main_model_name,
        trust_remote_code=True,
    )

    sampling_params = SamplingParams(
        temperature=0,
        ignore_eos=False,
        max_tokens=256,
    )

    # SP is only enabled when query_lens > 1000, hence the 1000-space padding.
    prompts = [
        {
            "role": "user",
            "content": " " * 1000 + "Hello, my name is",
        },
        {
            "role": "user",
            "content": " " * 1000 + "The president of the United States is",
        },
        {
            "role": "user",
            "content": " " * 1000 + "The capital of France is",
        },
        {
            "role": "user",
            "content": " " * 1000 + "The future of AI is",
        },
    ]
    prompts = [
        tokenizer.apply_chat_template(
            [prompt],
            tokenize=False,
            add_generation_prompt=True,
        ) for prompt in prompts
    ]

    speculative_config = {
        "method": method,
        "num_speculative_tokens": num_speculative_tokens,
        "disable_padded_drafter_batch": disable_padded_drafter_batch,
        "model": spec_model_name,
    }
    compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])

    with VllmRunner(
            main_model_name,
            enforce_eager=True,
            max_model_len=8192,
            disable_log_stats=False,
            tensor_parallel_size=1,
            max_num_seqs=256,
            distributed_executor_backend="mp",
            gpu_memory_utilization=0.7,
            speculative_config=speculative_config,
            compilation_config=compilation_config,
            async_scheduling=async_scheduling,
    ) as llm:
        _ = llm.generate(prompts, sampling_params)
        metrics = llm.model.get_metrics()

    num_drafts = 0
    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos in range(len(metric.values)):
                num_accepted_tokens_per_pos[pos] += metric.values[pos]

    acceptance_per_pos = [
        num_accepted_tokens / num_drafts
        for num_accepted_tokens in num_accepted_tokens_per_pos
    ]

    golden = BASELINES_SP[method]
    match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
    if not match:
        print(f"acceptance_per_pos: {acceptance_per_pos}")
        print(f"golden: {golden}")

    assert match
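As a rough sanity check on the `BASELINES_SP` numbers: `acceptance_per_pos[i]` is the fraction of drafts whose token at position `i` is accepted, so each verification step emits about one target-model token plus the accepted draft tokens:

```python
# Back-of-envelope check (assumes acceptance_per_pos as defined above).
golden = [0.68, 0.40, 0.18]
expected_tokens_per_step = 1 + sum(golden)  # bonus token + accepted drafts
print(expected_tokens_per_step)  # 2.26
```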

View File

@@ -275,6 +275,8 @@ class TestEagleProposerDummyRun(TestBase):
        num_tokens = 32
        with_prefill = False

        # CPU does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=num_tokens,
                                with_prefill=with_prefill)

@@ -284,6 +286,8 @@ class TestEagleProposerDummyRun(TestBase):
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
        mock_context.return_value.__enter__.return_value = None

        # CPU does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
        self.assertTrue(self.proposer.model.call_count == 4)

@@ -298,6 +302,8 @@ class TestEagleProposerDummyRun(TestBase):
        mock_return_context.capturing = True
        mock_get_context.return_value = mock_return_context
        self.proposer.use_cuda_graph = True

        # CPU does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64,
                                in_graph_capturing=True,
                                aclgraph_runtime_mode=CUDAGraphMode.FULL)

@@ -316,6 +322,8 @@ class TestEagleProposerDummyRun(TestBase):
        mock_return_context.capturing = False
        mock_get_context.return_value = mock_return_context
        self.proposer.use_cuda_graph = True

        # CPU does not support `torch.ops.vllm.maybe_pad_and_reduce`
        self.proposer.enable_shared_expert_dp = False
        self.proposer.dummy_run(num_tokens=64,
                                in_graph_capturing=False,
                                aclgraph_runtime_mode=CUDAGraphMode.FULL)

View File

@@ -225,7 +225,6 @@ class EagleProposer(VllmEagleProposer):
            decode_token_per_req=self.runner.decode_token_per_req,
            max_seq_len=0,
        )
-       dummy_compute_logits(self.hidden_states)

        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        attn_metadata_eagle = builder.build_for_graph_capture(

@@ -233,6 +232,10 @@
        attn_metadata = {}
        for layer_name in self.attn_layer_name:
            attn_metadata[layer_name] = attn_metadata_eagle

        model_input_ids = self.input_ids[:num_tokens]
        model_positions = self.positions[:num_tokens]
        model_previous_hidden_states = self.hidden_states[:num_tokens]

        for i in range(self.num_speculative_tokens):
            if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
                aclgraph_runtime_mode = CUDAGraphMode.NONE

@@ -245,12 +248,17 @@
                    batch_descriptor=batch_descriptor,
                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                    is_draft_model=True):
-               forward_context = get_forward_context()
                if self.enable_shared_expert_dp:
                    model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
                        model_previous_hidden_states)
                self.model(
-                   input_ids=self.input_ids[:num_tokens],
-                   positions=self.positions[:num_tokens],
-                   hidden_states=self.hidden_states[:num_tokens],
                    input_ids=model_input_ids,
                    positions=model_positions,
                    hidden_states=model_previous_hidden_states,
                )
                forward_context = get_forward_context()
                if (forward_context.cudagraph_runtime_mode
                        == CUDAGraphMode.FULL
                        and not forward_context.capturing):

@@ -261,6 +269,12 @@
                        self.vllm_config,
                    )

                if self.enable_shared_expert_dp:
                    model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                        model_previous_hidden_states, True)

        dummy_compute_logits(self.hidden_states)

    def _propose(
        self,
        # [num_tokens]

@@ -338,10 +352,24 @@
                batch_descriptor=batch_descriptor,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                is_draft_model=True):
            # `input_ids`, `positions` and `hidden_states` live across the
            # proposal steps of all speculative tokens, while
            # `model_input_ids`, `model_positions` and `model_hidden_states`
            # hold the inputs actually fed to the speculative model.
            model_input_ids = self.input_ids[:num_input_tokens]
            model_positions = self.positions[:num_input_tokens]
            model_hidden_states = self.hidden_states[:num_input_tokens]
            if self.enable_shared_expert_dp:
                # split hidden states along the sequence dimension
                # (`positions` should not be split)
                model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
                    model_hidden_states)

            # TODO: in acl-graph mode, should `model_hidden_states` be copied
            # back to `self.hidden_states`?
            last_hidden_states, hidden_states = self.model(
-               input_ids=self.input_ids[:num_input_tokens],
-               positions=self.positions[:num_input_tokens],
-               hidden_states=self.hidden_states[:num_input_tokens],
                input_ids=model_input_ids,
                positions=model_positions,
                hidden_states=model_hidden_states,
            )
            forward_context = get_forward_context()
            if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:

@@ -352,6 +380,14 @@
                    num_input_tokens,
                    self.vllm_config,
                )

            if self.enable_shared_expert_dp:
                # merge hidden states along the sequence dimension
                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                    last_hidden_states.contiguous(), True)
                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                    hidden_states.contiguous(), True)

            sample_hidden_states = last_hidden_states[last_token_indices]
            logits = self.model.compute_logits(sample_hidden_states)
            draft_token_ids = logits.argmax(dim=-1)

@@ -470,10 +506,23 @@
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                is_draft_model=True):
            # `input_ids`, `positions` and `hidden_states` live across the
            # proposal steps of all speculative tokens, while
            # `model_input_ids`, `model_positions` and `model_hidden_states`
            # hold the inputs actually fed to the speculative model.
            model_input_ids = self.input_ids[:input_batch_size]
            model_positions = self.positions[:input_batch_size]
            model_hidden_states = self.hidden_states[:input_batch_size]
            if self.enable_shared_expert_dp:
                # split hidden states along the sequence dimension
                # (`positions` should not be split)
                model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
                    model_hidden_states)

            # TODO: in acl-graph mode, should `model_hidden_states` be copied
            # back to `self.hidden_states`?
            last_hidden_states, hidden_states = self.model(
-               input_ids=self.input_ids[:input_batch_size],
-               positions=self.positions[:input_batch_size],
-               hidden_states=self.hidden_states[:input_batch_size],
                input_ids=model_input_ids,
                positions=model_positions,
                hidden_states=model_hidden_states,
            )
            forward_context = get_forward_context()
            if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:

@@ -483,6 +532,14 @@
                    input_batch_size,
                    self.vllm_config,
                )

            if self.enable_shared_expert_dp:
                # merge hidden states along the sequence dimension
                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                    last_hidden_states.contiguous(), True)
                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                    hidden_states.contiguous(), True)

            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])

View File

@@ -1056,6 +1056,33 @@ class NPUModelRunner(GPUModelRunner):
            input_ids, inputs_embeds, intermediate_tensors,
            max_num_scheduled_tokens)

    # all-gather a single hidden-states tensor in the sp scenario
    @staticmethod
    def _all_gather_hidden_states(hidden_states):
        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
        pad_size = get_forward_context().pad_size
        if pad_size > 0:
            hidden_states = hidden_states[:-pad_size, :]
        return hidden_states

    # all-gather a list of hidden-states tensors in the sp scenario
    @staticmethod
    def _all_gather_hidden_states_list(hidden_states_list):
        return [
            NPUModelRunner._all_gather_hidden_states(hidden_states)
            for hidden_states in hidden_states_list
        ]

    # all-gather the last layer's hidden states together with the
    # aux hidden states in the sp scenario
    @staticmethod
    def _all_gather_hidden_states_and_aux(hidden_states):
        if isinstance(hidden_states, tuple):
            return (NPUModelRunner._all_gather_hidden_states(hidden_states[0]),
                    NPUModelRunner._all_gather_hidden_states_list(
                        hidden_states[1]))
        return NPUModelRunner._all_gather_hidden_states(hidden_states)

    def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
                                             input_ids, positions,
                                             intermediate_tensors,

@@ -1103,10 +1130,8 @@ class NPUModelRunner(GPUModelRunner):
        if get_forward_context().sp_enabled and not isinstance(
                hidden_states, IntermediateTensors):
-           hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
-           pad_size = get_forward_context().pad_size
-           if pad_size > 0:
-               hidden_states = hidden_states[:-pad_size, :]
            hidden_states = self._all_gather_hidden_states_and_aux(
                hidden_states)
        return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
            hidden_states)
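
To illustrate the tuple-vs-tensor dispatch in `_all_gather_hidden_states_and_aux`: the model returns either a plain hidden-states tensor or, for an eagle3 target model, a `(last_hidden_states, aux_hidden_states_list)` tuple, and the helper gathers each leaf the same way. Below is a minimal shape-level sketch, with `fake_all_gather` and `PAD_SIZE` as stand-ins for `tensor_model_parallel_all_gather` and `get_forward_context().pad_size`:

```python
# Shape-level sketch (stand-ins, not the vLLM code) of the dispatch above.
import torch

WORLD_SIZE = 2
PAD_SIZE = 3  # padding rows added before the sequence split


def fake_all_gather(x: torch.Tensor) -> torch.Tensor:
    # Concatenate along dim 0, as a 2-rank all-gather would.
    return torch.cat([x] * WORLD_SIZE, dim=0)


def all_gather_and_unpad(x: torch.Tensor) -> torch.Tensor:
    x = fake_all_gather(x)
    return x[:-PAD_SIZE, :] if PAD_SIZE > 0 else x


def all_gather_hidden_states_and_aux(hidden_states):
    if isinstance(hidden_states, tuple):
        last, aux_list = hidden_states
        return (all_gather_and_unpad(last),
                [all_gather_and_unpad(a) for a in aux_list])
    return all_gather_and_unpad(hidden_states)


local = torch.randn(8, 16)  # [num_local_tokens, hidden_size]
aux = [torch.randn(8, 16) for _ in range(3)]  # aux hidden states (eagle3)
last_full, aux_full = all_gather_hidden_states_and_aux((local, aux))
print(last_full.shape)    # torch.Size([13, 16])  -> 8 * 2 - 3
print(aux_full[0].shape)  # torch.Size([13, 16])
```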