[Bugfix] Fixed an accuracy problem of sp with eagle3 (#5816)
### What this PR does / why we need it?
Fixed an accuracy problem when using eagle3 with sp.
The problem is described in
https://github.com/vllm-project/vllm-ascend/issues/5825.
It also adds a much more precise way to determine whether the drafter should
use `sp`: instead of keying off `enable_shared_expert_dp`, the drafter now
reads `sp_enabled` from the per-step forward context.
It also changes the drafter's `eager` mode to a real frontend `eager` mode,
by temporarily forcing `CompilationMode.NONE` while the draft model is
loaded, to work around an `fx-graph` problem.
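For illustration, here is a minimal, self-contained sketch of the save-and-restore pattern behind that workaround; `force_eager` and the `SimpleNamespace` config are hypothetical stand-ins for the real `_maybe_eager_context` helper and vLLM's compilation config shown in the diff below.
```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def force_eager(compilation_config, none_mode=0):
    # Temporarily disable compilation so the draft model is built truly
    # eager, then restore the original mode even if loading raises.
    saved = compilation_config.mode
    compilation_config.mode = none_mode  # CompilationMode.NONE in vLLM
    try:
        yield
    finally:
        compilation_config.mode = saved


cfg = SimpleNamespace(mode=3)
with force_eager(cfg):
    assert cfg.mode == 0  # compilation off while the draft model loads
assert cfg.mode == 3      # original mode restored afterwards
```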
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
For simplicity, we test it as in
https://github.com/vllm-project/vllm-ascend/issues/5825
and get the same acceptance statistics as `eagle3` with `sp` disabled:
```text
--------------------------------------------------
total_num_output_tokens: 1000
num_drafts: 437
num_draft_tokens: 1311
num_accepted_tokens: 564
mean acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.62
acceptance at token 1: 0.40
acceptance at token 2: 0.27
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00
```
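As a quick sanity check on these counters, the relationships below are inferred from the numbers themselves (not quoted from vLLM's logging code):
```python
num_drafts = 437
num_draft_tokens = 1311
num_accepted_tokens = 564

# Each draft contributes one bonus token plus its accepted draft tokens.
mean_acceptance_length = 1 + num_accepted_tokens / num_drafts
print(f"mean acceptance length: {mean_acceptance_length:.2f}")  # 2.29

# Three speculative tokens were proposed per draft: 437 * 3 == 1311.
assert num_draft_tokens == num_drafts * 3
```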
* vLLM version: v0.13.0
* vLLM main:
2f4e6548ef
Signed-off-by: drslark <slarksblood@qq.com>
```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from contextlib import nullcontext
-from typing import Optional
+from contextlib import contextmanager, nullcontext
+from typing import Any, ContextManager, Optional

 import numpy as np
 import torch
@@ -8,7 +8,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
-from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
+from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group,
                                              patch_tensor_parallel_group)
 from vllm.forward_context import get_forward_context
@@ -42,12 +43,45 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
 from vllm_ascend.ops.triton.spec_decode.utils import \
     prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled

 # Currently we fix the block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4


+# TODO: Remove this when the fx-graph bug is solved.
+# Temporarily patch vllm_config to CompilationMode.NONE.
+@contextmanager
+def _maybe_eager_context(vllm_config):
+    raw_compilation_config_mode = vllm_config.compilation_config.mode
+    vllm_config.compilation_config.mode = CompilationMode.NONE
+    try:
+        yield
+    finally:
+        vllm_config.compilation_config.mode = raw_compilation_config_mode
+
+
+# Split hidden states along the sequence dimension.
+def split_inputs_tp_to_sp(hidden_states, out):
+    # tp and sp share the same group
+    group = get_tp_group()
+
+    world_size = group.world_size
+    rank = group.rank
+
+    num_tokens = hidden_states.shape[0]
+    # the padded size per rank (ceil division)
+    padded_num_tokens_per_rank = (num_tokens + world_size - 1) // world_size
+    # compute the start and end of this rank's slice
+    start = padded_num_tokens_per_rank * rank
+    end = padded_num_tokens_per_rank * (rank + 1)
+
+    # copy only the hidden states that belong to the current rank
+    hidden_states_curr_rank = hidden_states[start:end]
+    out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank
+    return out[:padded_num_tokens_per_rank]
+
+
 class EagleProposer(VllmEagleProposer):

     def __init__(self,
```
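To make the slicing arithmetic in `split_inputs_tp_to_sp` concrete, here is a minimal single-process sketch; `split_for_rank` is a hypothetical stand-in that takes `rank` and `world_size` directly instead of calling `get_tp_group()`, and it omits the padded `out` buffer the real helper writes into.
```python
import torch


def split_for_rank(hidden_states, rank, world_size):
    # Same ceil-division math as the helper above: every rank gets a
    # fixed-size window, and only the trailing rank(s) can come up short.
    num_tokens = hidden_states.shape[0]
    per_rank = (num_tokens + world_size - 1) // world_size
    return hidden_states[per_rank * rank:per_rank * (rank + 1)]


hs = torch.arange(10.0).unsqueeze(-1)  # 10 tokens, hidden size 1
shards = [split_for_rank(hs, r, world_size=4) for r in range(4)]
print([s.shape[0] for s in shards])  # [3, 3, 3, 1] -- last rank is short
assert torch.equal(torch.cat(shards), hs)  # shards tile the sequence exactly
```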
```diff
@@ -118,6 +152,11 @@ class EagleProposer(VllmEagleProposer):
         else:
             self.tp_group_context = nullcontext()

+        # TODO: Remove this when the fx-graph bug is solved.
+        self.maybe_eager_context: ContextManager[Any] = nullcontext()
+        if not self.use_cuda_graph and enable_sp(vllm_config):
+            self.maybe_eager_context = _maybe_eager_context(vllm_config)
+
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
@@ -126,9 +165,10 @@ class EagleProposer(VllmEagleProposer):
             get_layers_from_vllm_config(self.vllm_config,
                                         DeepseekV32IndexerCache).keys())

-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=self.vllm_config.
-                               speculative_config.draft_model_config)
+        with self.maybe_eager_context:
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=self.vllm_config.
+                                   speculative_config.draft_model_config)

         indexer_layers = get_layers_from_vllm_config(
             self.vllm_config, DeepseekV32IndexerCache).keys()
@@ -273,8 +313,10 @@ class EagleProposer(VllmEagleProposer):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):

-            if self.enable_shared_expert_dp:
-                model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                model_previous_hidden_states = split_inputs_tp_to_sp(
+                    model_previous_hidden_states,
                     model_previous_hidden_states)

             self.model(
@@ -282,7 +324,6 @@ class EagleProposer(VllmEagleProposer):
                 positions=model_positions,
                 hidden_states=model_previous_hidden_states,
             )
-            forward_context = get_forward_context()
             if (forward_context.cudagraph_runtime_mode
                     == CUDAGraphMode.FULL
                     and not forward_context.capturing):
@@ -293,7 +334,7 @@ class EagleProposer(VllmEagleProposer):
                     self.vllm_config,
                 )

-            if self.enable_shared_expert_dp:
+            if forward_context.sp_enabled:
                 model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     model_previous_hidden_states, True)
```
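The dummy-run path above (and the propose paths below) now follow a split-before/gather-after pattern around the drafter forward. Here is a single-process sketch of that round trip under the ceil-division padding scheme; `pad_shard` and the plain `torch.cat` are stand-ins for the real collective, `torch.ops.vllm.maybe_all_gather_and_maybe_unpad`.
```python
import torch
import torch.nn.functional as F


def pad_shard(shard, per_rank):
    # Pad a short trailing shard up to the common per-rank length.
    return F.pad(shard, (0, 0, 0, per_rank - shard.shape[0]))


num_tokens, world_size, hidden = 10, 4, 8
per_rank = -(-num_tokens // world_size)  # ceil division -> 3
hs = torch.randn(num_tokens, hidden)

# Each "rank" takes its padded slice (shape [3, 8] on every rank) ...
shards = [pad_shard(hs[per_rank * r:per_rank * (r + 1)], per_rank)
          for r in range(world_size)]
# ... and a gather followed by unpadding recovers the original sequence,
# since ceil division leaves all padding contiguous at the tail.
gathered = torch.cat(shards)[:num_tokens]
assert torch.equal(gathered, hs)
```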
```diff
@@ -383,19 +424,19 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self.positions[:num_input_tokens]
         model_hidden_states = self.hidden_states[:num_input_tokens]

-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
             # split hidden states along the sequence dimension
             # positions should not be split?
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
+            # in acl-graph, should `model_hidden_states` be copied back to `self.hidden_states`?
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)

         last_hidden_states, hidden_states = self.model(
             input_ids=model_input_ids,
             positions=model_positions,
             hidden_states=model_hidden_states,
         )
-        forward_context = get_forward_context()

         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             # TODO: support mla in the future.
             update_attn_params(
@@ -405,7 +446,7 @@ class EagleProposer(VllmEagleProposer):
                     self.vllm_config,
                 )

-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
             # merge hidden states along the sequence dimension
             last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 last_hidden_states.contiguous(), True)
@@ -536,19 +577,18 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self.positions[:input_batch_size]
         model_hidden_states = self.hidden_states[:input_batch_size]

-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
             # split hidden states along the sequence dimension
             # positions should not be split?
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
+            # in acl-graph, should `model_hidden_states` be copied back to `self.hidden_states`?
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)

         last_hidden_states, hidden_states = self.model(
             input_ids=model_input_ids,
             positions=model_positions,
             hidden_states=model_hidden_states,
         )
-        forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             update_attn_params(
                 self.update_stream,
@@ -557,7 +597,7 @@ class EagleProposer(VllmEagleProposer):
                     self.vllm_config,
                 )

-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
             # merge hidden states along the sequence dimension
             last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 last_hidden_states.contiguous(), True)
```