[Bugfix] Fix an accuracy problem when using eagle3 with sp (#5816)

### What this PR does / why we need it?
Fixed an accuracy problem when using eagle3 with sp.

The problem is described in
https://github.com/vllm-project/vllm-ascend/issues/5825.

It also adds a more precise way to determine whether the drafter should
use `sp` or not.
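
Concretely, the drafter now keys this decision off the forward context rather than a static flag. A minimal sketch of the check (the helper name is ours; `get_forward_context` and `sp_enabled` are the names used in the diff below):

```python
from vllm.forward_context import get_forward_context


def drafter_should_use_sp() -> bool:
    # sp is decided per forward pass, so this must run inside the
    # forward path where a forward context is available
    return get_forward_context().sp_enabled
```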

It also changes the drafter's `eager` mode to a real frontend `eager`
mode, to avoid an `fx-graph` problem.
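
The eager switch temporarily forces `CompilationMode.NONE` around draft-model loading; condensed, the `_maybe_eager_context` helper added in the diff below looks like this:

```python
from contextlib import contextmanager

from vllm.config import CompilationMode


@contextmanager
def _maybe_eager_context(vllm_config):
    # temporarily patch vllm_config so the frontend loads the model in eager mode
    saved_mode = vllm_config.compilation_config.mode
    vllm_config.compilation_config.mode = CompilationMode.NONE
    try:
        yield
    finally:
        vllm_config.compilation_config.mode = saved_mode
```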

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

For simplicity, we test it as described in
https://github.com/vllm-project/vllm-ascend/issues/5825.

With the fix, we get the same result as `eagle3` with `sp` disabled:

```text
--------------------------------------------------
total_num_output_tokens: 1000
num_drafts: 437
num_draft_tokens: 1311
num_accepted_tokens: 564
mean acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.62
acceptance at token 1: 0.40
acceptance at token 2: 0.27
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00
```
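
The reported mean acceptance length is consistent with the counters above: 1 + `num_accepted_tokens` / `num_drafts` = 1 + 564 / 437 ≈ 2.29.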

* vLLM version: v0.13.0
* vLLM main: 2f4e6548ef

Signed-off-by: drslark <slarksblood@qq.com>
7 changed files with 246 additions and 141 deletions

```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from contextlib import nullcontext
-from typing import Optional
+from contextlib import contextmanager, nullcontext
+from typing import Any, ContextManager, Optional
 
 import numpy as np
 import torch
@@ -8,7 +8,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
-from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
+from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group,
                                              patch_tensor_parallel_group)
 from vllm.forward_context import get_forward_context
@@ -42,12 +43,45 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
 from vllm_ascend.ops.triton.spec_decode.utils import \
     prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled
 
 # Currently we will fix block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
 
 
+# TODO: Remove it when the bug of fx-graph is solved
+# patch vllm_config to be in CompilationMode.NONE temporarily
+@contextmanager
+def _maybe_eager_context(vllm_config):
+    raw_compilation_config_mode = vllm_config.compilation_config.mode
+    vllm_config.compilation_config.mode = CompilationMode.NONE
+    try:
+        yield
+    finally:
+        vllm_config.compilation_config.mode = raw_compilation_config_mode
+
+
+# split hidden states along the sequence dimension
+def split_inputs_tp_to_sp(hidden_states, out):
+    # tp and sp share the same group
+    group = get_tp_group()
+    world_size = group.world_size
+    rank = group.rank
+
+    num_tokens = hidden_states.shape[0]
+    # the padded size per rank
+    padded_num_tokens_per_rank = (num_tokens + world_size - 1) // world_size
+
+    # compute the start and end of the slice
+    start = padded_num_tokens_per_rank * rank
+    end = padded_num_tokens_per_rank * (rank + 1)
+
+    # copy only the hidden_states of the current rank
+    hidden_states_curr_rank = hidden_states[start:end]
+    out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank
+
+    return out[:padded_num_tokens_per_rank]
+
+
 class EagleProposer(VllmEagleProposer):
 
     def __init__(self,
@@ -118,6 +152,11 @@ class EagleProposer(VllmEagleProposer):
         else:
             self.tp_group_context = nullcontext()
 
+        # TODO: Remove it when the bug of fx-graph is solved
+        self.maybe_eager_context: ContextManager[Any] = nullcontext()
+        if not self.use_cuda_graph and enable_sp(vllm_config):
+            self.maybe_eager_context = _maybe_eager_context(vllm_config)
+
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
@@ -126,9 +165,10 @@ class EagleProposer(VllmEagleProposer):
             get_layers_from_vllm_config(self.vllm_config,
                                         DeepseekV32IndexerCache).keys())
 
-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=self.vllm_config.
-                               speculative_config.draft_model_config)
+        with self.maybe_eager_context:
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=self.vllm_config.
+                                   speculative_config.draft_model_config)
 
         indexer_layers = get_layers_from_vllm_config(
             self.vllm_config, DeepseekV32IndexerCache).keys()
@@ -273,8 +313,10 @@ class EagleProposer(VllmEagleProposer):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):
-            if self.enable_shared_expert_dp:
-                model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                    model_previous_hidden_states)
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                model_previous_hidden_states = split_inputs_tp_to_sp(
+                    model_previous_hidden_states,
+                    model_previous_hidden_states)
 
             self.model(
@@ -282,7 +324,6 @@ class EagleProposer(VllmEagleProposer):
                 positions=model_positions,
                 hidden_states=model_previous_hidden_states,
             )
-            forward_context = get_forward_context()
             if (forward_context.cudagraph_runtime_mode
                     == CUDAGraphMode.FULL
                     and not forward_context.capturing):
@@ -293,7 +334,7 @@ class EagleProposer(VllmEagleProposer):
                     self.vllm_config,
                 )
 
-            if self.enable_shared_expert_dp:
+            if forward_context.sp_enabled:
                 model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     model_previous_hidden_states, True)
@@ -383,19 +424,19 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self.positions[:num_input_tokens]
         model_hidden_states = self.hidden_states[:num_input_tokens]
 
-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
             # split hidden states along sequence dimension
             # positions should not be split?
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
+            # in acl-graph, should `model_hidden_states` be copied back to `self.hidden_states`?
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)
 
         last_hidden_states, hidden_states = self.model(
             input_ids=model_input_ids,
             positions=model_positions,
             hidden_states=model_hidden_states,
         )
-        forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             # TODO: support mla in future.
             update_attn_params(
@@ -405,7 +446,7 @@ class EagleProposer(VllmEagleProposer):
                 self.vllm_config,
             )
 
-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
             # merge hidden states along sequence dimension
             last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 last_hidden_states.contiguous(), True)
@@ -536,19 +577,18 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self.positions[:input_batch_size]
         model_hidden_states = self.hidden_states[:input_batch_size]
 
-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
             # split hidden states along sequence dimension
             # positions should not be split
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
+            # in acl-graph, should `model_hidden_states` be copied back to `self.hidden_states`?
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)
 
         last_hidden_states, hidden_states = self.model(
             input_ids=model_input_ids,
             positions=model_positions,
             hidden_states=model_hidden_states,
         )
-        forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             update_attn_params(
                 self.update_stream,
@@ -557,7 +597,7 @@ class EagleProposer(VllmEagleProposer):
                 self.vllm_config,
             )
 
-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
             # merge hidden states along sequence dimension
             last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 last_hidden_states.contiguous(), True)
```
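
For intuition, the slicing math in `split_inputs_tp_to_sp` can be exercised standalone. The sketch below is hypothetical: `split_tokens_for_rank` is our name, and `rank`/`world_size` are passed in explicitly instead of being read from the TP group as the real helper does.

```python
import torch


def split_tokens_for_rank(hidden_states: torch.Tensor, rank: int,
                          world_size: int) -> torch.Tensor:
    # ceil-divide so every rank gets the same padded slice length
    num_tokens = hidden_states.shape[0]
    per_rank = (num_tokens + world_size - 1) // world_size
    chunk = hidden_states[per_rank * rank:per_rank * (rank + 1)]
    # the last rank's short chunk is zero-padded to the common length
    out = torch.zeros(per_rank, *hidden_states.shape[1:],
                      dtype=hidden_states.dtype)
    out[:chunk.shape[0]] = chunk
    return out


# 10 tokens on 4 ranks -> 3 tokens per rank; rank 3 holds 1 real token
h = torch.arange(10, dtype=torch.float32).unsqueeze(-1)
print([split_tokens_for_rank(h, r, 4).squeeze(-1).tolist() for r in range(4)])
# [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 0.0]]
```

Unlike the real helper, which writes into a caller-provided `out` buffer, this sketch allocates its own output.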