[Bugfix] Fixed an accuracy problem of sp with eagle3 (#5816)

### What this PR does / why we need it?
Fixed an accuracy problem when using eagle3 with sp.

The problem is described in
https://github.com/vllm-project/vllm-ascend/issues/5825.

It also adds a more precise way to determine whether the drafter should use `sp` or not.
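
Roughly, the drafter now reads that decision from the forward context that
`set_ascend_forward_context` prepares for the draft model, instead of reusing a
main-model flag such as shared-expert DP. A minimal sketch of the check (the
helper name below is made up for illustration; the real change lives in
`eagle_proposer.py`):

```python
# Sketch only: the drafter asks the current forward context whether sp is
# active; `sp_enabled` is filled in by set_ascend_forward_context, which now
# inspects the drafter's own architecture when is_draft_model=True.
from vllm.forward_context import get_forward_context


def drafter_sp_enabled() -> bool:
    return get_forward_context().sp_enabled
```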

It also makes the drafter's `eager` mode a real `eager` on the frontend, to avoid an `fx-graph` problem.
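
Concretely, the drafter's model load is wrapped in a small context manager that
temporarily forces `CompilationMode.NONE`, so the frontend really stays eager.
A rough sketch of the idea (the actual helper in this PR is
`_maybe_eager_context` in `vllm_ascend/spec_decode/eagle_proposer.py`):

```python
# Rough sketch, not the verbatim patch: restore the original compilation mode
# even if loading the drafter raises.
from contextlib import contextmanager

from vllm.config import CompilationMode


@contextmanager
def _force_eager(vllm_config):
    saved_mode = vllm_config.compilation_config.mode
    vllm_config.compilation_config.mode = CompilationMode.NONE
    try:
        yield
    finally:
        vllm_config.compilation_config.mode = saved_mode
```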

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

For simplicity, we test it as described in
https://github.com/vllm-project/vllm-ascend/issues/5825.

With the fix, we get the same result as `eagle3` with `sp` disabled:

```text
--------------------------------------------------
total_num_output_tokens: 1000
num_drafts: 437
num_draft_tokens: 1311
num_accepted_tokens: 564
mean acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.62
acceptance at token 1: 0.40
acceptance at token 2: 0.27
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00
```
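
As a sanity check, these numbers are self-consistent: `num_draft_tokens = num_drafts * 3 = 437 * 3 = 1311`, and the mean acceptance length matches both `1 + num_accepted_tokens / num_drafts = 1 + 564 / 437 ≈ 2.29` and the sum of the per-position rates, `1 + 0.62 + 0.40 + 0.27 ≈ 2.29`.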

* vLLM version: v0.13.0
* vLLM main: 2f4e6548ef

Signed-off-by: drslark <slarksblood@qq.com>
Author: drslark
Date: 2026-01-14 09:00:37 +08:00
Committed by: GitHub
Parent: e1bed43cff
Commit: 48ec97821a
7 changed files with 246 additions and 141 deletions

View File

@@ -217,7 +217,7 @@ jobs:
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
+ pytest -sv --durations=0 tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py
# To avoid oom, we need to run the test in a single process.
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2

View File

@@ -0,0 +1,156 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Run `pytest tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py`.
from __future__ import annotations

import math
import os
import random
from typing import Any, Union
from unittest.mock import patch

import pytest
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.v1.metrics.reader import Counter, Vector

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODELS = {
    "eagle3": {
        "main": "Qwen/Qwen3-8B",
        "spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
    },
}

# NOTE: golden may change (eagle_proposer only runs in eager mode currently),
# thus please update it if ci fails but you have better acceptance
BASELINES_SP = {
    "eagle3": [0.68, 0.40, 0.18],
}


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@pytest.mark.parametrize("method", ["eagle3"])
@pytest.mark.parametrize("num_speculative_tokens", [3])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_eagle3_sp_acceptance(
    method: str,
    num_speculative_tokens: int,
    disable_padded_drafter_batch: bool,
    async_scheduling: bool,
):
    if disable_padded_drafter_batch and async_scheduling:
        pytest.skip(
            "skip disable_padded_drafter_batch=True and async_scheduling=True",
        )

    main_model_name = MODELS[method]["main"]
    spec_model_name = MODELS[method]["spec"]

    tokenizer = AutoTokenizer.from_pretrained(
        main_model_name,
        trust_remote_code=True,
    )
    sampling_params = SamplingParams(
        temperature=0,
        ignore_eos=False,
        max_tokens=256,
    )

    # sp will only be enabled when query_lens > 1000
    prompts = [
        {
            "role": "user",
            "content": " " * 1000 + "Hello, my name is",
        },
        {
            "role": "user",
            "content": " " * 1000 + "The president of the United States is",
        },
        {
            "role": "user",
            "content": " " * 1000 + "The capital of France is",
        },
        {
            "role": "user",
            "content": " " * 1000 + "The future of AI is",
        },
    ]
    prompts = [
        tokenizer.apply_chat_template(
            [prompt],
            tokenize=False,
            add_generation_prompt=True,
        ) for prompt in prompts
    ]

    speculative_config = {
        "enforce_eager": True,
        "method": method,
        "num_speculative_tokens": num_speculative_tokens,
        "disable_padded_drafter_batch": disable_padded_drafter_batch,
        "model": spec_model_name,
    }
    compilation_config = CompilationConfig(cudagraph_mode="FULL_DECODE_ONLY",
                                           cudagraph_capture_sizes=[12])

    with VllmRunner(
            main_model_name,
            enforce_eager=True,
            max_model_len=8192,
            disable_log_stats=False,
            tensor_parallel_size=2,
            max_num_seqs=256,
            distributed_executor_backend="mp",
            gpu_memory_utilization=0.7,
            speculative_config=speculative_config,
            compilation_config=compilation_config,
            async_scheduling=async_scheduling,
    ) as llm:
        _ = llm.generate(prompts, sampling_params)
        metrics = llm.model.get_metrics()

    num_drafts = 0
    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos in range(len(metric.values)):
                num_accepted_tokens_per_pos[pos] += metric.values[pos]

    acceptance_per_pos = [
        num_accepted_tokens / num_drafts
        for num_accepted_tokens in num_accepted_tokens_per_pos
    ]
    golden = BASELINES_SP[method]
    match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
    if not match:
        print(f"acceptance_per_pos: {acceptance_per_pos}")
        print(f"golden: {golden}")
    assert match

View File

@@ -34,11 +34,6 @@ BASELINES = {
    "eagle3": [0.68, 0.40, 0.18],
}

-BASELINES_SP = {
-    "eagle3": [0.68, 0.40, 0.18],
-}

@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
@@ -381,111 +376,3 @@ def test_llama_qwen_eagle_acceptance(
        print(f"golden: {golden}")
    assert match
-
-
-# TODO the function of sp in eagle3 is improving gradually,
-# there are still problems when enable sp + dp and some unknown scenes.
-# this e2e should also be improving gradually.
-@pytest.mark.parametrize("method", ["eagle3"])
-@pytest.mark.parametrize("num_speculative_tokens", [3])
-@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
-@pytest.mark.parametrize("async_scheduling", [True, False])
-def test_eagle3_sp_acceptance(
-    method: str,
-    num_speculative_tokens: int,
-    disable_padded_drafter_batch: bool,
-    async_scheduling: bool,
-):
-    if disable_padded_drafter_batch and async_scheduling:
-        pytest.skip(
-            "skip disable_padded_drafter_batch=True and async_scheduling=True",
-        )
-    main_model_name = MODELS[method]["main"]
-    spec_model_name = MODELS[method]["spec"]
-    tokenizer = AutoTokenizer.from_pretrained(
-        main_model_name,
-        trust_remote_code=True,
-    )
-    sampling_params = SamplingParams(
-        temperature=0,
-        ignore_eos=False,
-        max_tokens=256,
-    )
-    # sp will only be enabled when query_lens > 1000
-    prompts = [
-        {
-            "role": "user",
-            "content": " " * 1000 + "Hello, my name is",
-        },
-        {
-            "role": "user",
-            "content": " " * 1000 + "The president of the United States is",
-        },
-        {
-            "role": "user",
-            "content": " " * 1000 + "The capital of France is",
-        },
-        {
-            "role": "user",
-            "content": " " * 1000 + "The future of AI is",
-        },
-    ]
-    prompts = [
-        tokenizer.apply_chat_template(
-            [prompt],
-            tokenize=False,
-            add_generation_prompt=True,
-        ) for prompt in prompts
-    ]
-    speculative_config = {
-        "method": method,
-        "num_speculative_tokens": num_speculative_tokens,
-        "disable_padded_drafter_batch": disable_padded_drafter_batch,
-        "model": spec_model_name,
-    }
-    compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
-    with VllmRunner(
-            main_model_name,
-            enforce_eager=True,
-            max_model_len=8192,
-            disable_log_stats=False,
-            tensor_parallel_size=1,
-            max_num_seqs=256,
-            distributed_executor_backend="mp",
-            gpu_memory_utilization=0.7,
-            speculative_config=speculative_config,
-            compilation_config=compilation_config,
-            async_scheduling=async_scheduling,
-    ) as llm:
-        _ = llm.generate(prompts, sampling_params)
-        metrics = llm.model.get_metrics()
-    num_drafts = 0
-    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
-    for metric in metrics:
-        if metric.name == "vllm:spec_decode_num_drafts":
-            assert isinstance(metric, Counter)
-            num_drafts += metric.value
-        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
-            assert isinstance(metric, Vector)
-            for pos in range(len(metric.values)):
-                num_accepted_tokens_per_pos[pos] += metric.values[pos]
-    acceptance_per_pos = [
-        num_accepted_tokens / num_drafts
-        for num_accepted_tokens in num_accepted_tokens_per_pos
-    ]
-    golden = BASELINES_SP[method]
-    match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
-    if not match:
-        print(f"acceptance_per_pos: {acceptance_per_pos}")
-        print(f"golden: {golden}")
-    assert match

View File

@@ -275,7 +275,9 @@ class TestEagleProposerDummyRun(TestBase):
        self.mock_cpugpubuffer.stop()
        self.mock_supports_multimodal_inputs.stop()

-    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
+    # cpu does not support parallel-group, let alone `sp`
+    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
+           **{"return_value.sp_enabled": False})
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_basic(self, mock_context, mock_get_context):
        num_tokens = 32
@@ -288,7 +290,9 @@ class TestEagleProposerDummyRun(TestBase):
        self.assertTrue(self.proposer.model.call_count == 4)

-    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
+    # cpu does not support parallel-group, let alone `sp`
+    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
+           **{"return_value.sp_enabled": False})
    @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
    def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
        mock_context.return_value.__enter__.return_value = None
@@ -306,6 +310,8 @@ class TestEagleProposerDummyRun(TestBase):
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_return_context.capturing = True
+        # cpu does not support parallel-group, let alone `sp`
+        mock_return_context.sp_enabled = False
        mock_get_context.return_value = mock_return_context
        self.proposer.use_cuda_graph = True
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
@@ -326,6 +332,8 @@ class TestEagleProposerDummyRun(TestBase):
        mock_return_context = MagicMock()
        mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
        mock_return_context.capturing = False
+        # cpu does not support parallel-group, let alone `sp`
+        mock_return_context.sp_enabled = False
        mock_get_context.return_value = mock_return_context
        self.proposer.use_cuda_graph = True
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`

View File

@@ -14,7 +14,7 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
                               get_ascend_device_type, has_layer_idx,
-                              is_moe_model,
+                              is_drafter_moe_model, is_moe_model,
                               speculative_enable_dispatch_gmm_combine_decode)
@@ -73,7 +73,10 @@ def set_ascend_forward_context(
    # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
    # the performance may degrade due to the switching of communication methods.
    mmrs_fusion = True
-    if is_moe_model(vllm_config):
+    # main model and drafter model may have different architecture
+    is_context_moe_model = is_drafter_moe_model(vllm_config) \
+        if is_draft_model else is_moe_model(vllm_config)
+    if is_context_moe_model:
        sp_enabled = enable_sp(vllm_config) and num_tokens is not None
        mmrs_fusion = False
    else:

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
-from contextlib import nullcontext
-from typing import Optional
+from contextlib import contextmanager, nullcontext
+from typing import Any, ContextManager, Optional

import numpy as np
import torch
@@ -8,7 +8,8 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
-from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
+from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group,
                                              patch_tensor_parallel_group)
from vllm.forward_context import get_forward_context
@@ -42,12 +43,45 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import \
    prepare_inputs_padded_kernel
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled

# Currently we will fix block size to a small one since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4


+# TODO: Remove it when the bug of fx-graph is solved
+# patch vllm_config to be in CompilationMode.NONE temporarily
+@contextmanager
+def _maybe_eager_context(vllm_config):
+    raw_compilation_config_mode = vllm_config.compilation_config.mode
+    vllm_config.compilation_config.mode = CompilationMode.NONE
+    try:
+        yield
+    finally:
+        vllm_config.compilation_config.mode = raw_compilation_config_mode
+
+
+# split hidden states along dimension of sequence
+def split_inputs_tp_to_sp(hidden_states, out):
+    # tp and sp share the same group
+    group = get_tp_group()
+    world_size = group.world_size
+    rank = group.rank
+
+    num_tokens = hidden_states.shape[0]
+    # the size per rank after padded
+    padded_num_tokens_per_rank = (num_tokens + world_size - 1) // world_size
+
+    # compute the start and end of slice
+    start = padded_num_tokens_per_rank * rank
+    end = padded_num_tokens_per_rank * (rank + 1)
+
+    # copy only hidden_states in current rank
+    hidden_states_curr_rank = hidden_states[start:end]
+    out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank
+
+    return out[:padded_num_tokens_per_rank]
+
+
class EagleProposer(VllmEagleProposer):

    def __init__(self,
@@ -118,6 +152,11 @@ class EagleProposer(VllmEagleProposer):
        else:
            self.tp_group_context = nullcontext()

+        # TODO: Remove it when the bug of fx-graph is solved
+        self.maybe_eager_context: ContextManager[Any] = nullcontext()
+        if not self.use_cuda_graph and enable_sp(vllm_config):
+            self.maybe_eager_context = _maybe_eager_context(vllm_config)
+
    def load_model(self, model: nn.Module) -> None:
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config,
@@ -126,6 +165,7 @@ class EagleProposer(VllmEagleProposer):
            get_layers_from_vllm_config(self.vllm_config,
                                        DeepseekV32IndexerCache).keys())

-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=self.vllm_config.
-                               speculative_config.draft_model_config)
+        with self.maybe_eager_context:
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=self.vllm_config.
+                                   speculative_config.draft_model_config)
@@ -273,8 +313,10 @@ class EagleProposer(VllmEagleProposer):
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                is_draft_model=True):
-            if self.enable_shared_expert_dp:
-                model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                    model_previous_hidden_states)
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                model_previous_hidden_states = split_inputs_tp_to_sp(
+                    model_previous_hidden_states,
+                    model_previous_hidden_states)
            self.model(
@@ -282,7 +324,6 @@ class EagleProposer(VllmEagleProposer):
                positions=model_positions,
                hidden_states=model_previous_hidden_states,
            )
-            forward_context = get_forward_context()
            if (forward_context.cudagraph_runtime_mode
                    == CUDAGraphMode.FULL
                    and not forward_context.capturing):
@@ -293,7 +334,7 @@ class EagleProposer(VllmEagleProposer):
                    self.vllm_config,
                )
-            if self.enable_shared_expert_dp:
+            if forward_context.sp_enabled:
                model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                    model_previous_hidden_states, True)
@@ -383,19 +424,19 @@ class EagleProposer(VllmEagleProposer):
        model_positions = self.positions[:num_input_tokens]
        model_hidden_states = self.hidden_states[:num_input_tokens]

-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
            # split hidden states along sequence dimension
            # positions should not be split?
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
-        # in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)

        last_hidden_states, hidden_states = self.model(
            input_ids=model_input_ids,
            positions=model_positions,
            hidden_states=model_hidden_states,
        )
-        forward_context = get_forward_context()
        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
            # TODO: support mla in future.
            update_attn_params(
@@ -405,7 +446,7 @@ class EagleProposer(VllmEagleProposer):
                self.vllm_config,
            )
-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
            # merge hidden states along sequence dimension
            last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                last_hidden_states.contiguous(), True)
@@ -536,19 +577,18 @@ class EagleProposer(VllmEagleProposer):
        model_positions = self.positions[:input_batch_size]
        model_hidden_states = self.hidden_states[:input_batch_size]

-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
            # split hidden states along sequence dimension
            # positions should not be split
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
-        # in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)

        last_hidden_states, hidden_states = self.model(
            input_ids=model_input_ids,
            positions=model_positions,
            hidden_states=model_hidden_states,
        )
-        forward_context = get_forward_context()
        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
            update_attn_params(
                self.update_stream,
@@ -557,7 +597,7 @@ class EagleProposer(VllmEagleProposer):
                self.vllm_config,
            )
-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
            # merge hidden states along sequence dimension
            last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                last_hidden_states.contiguous(), True)

View File

@@ -62,6 +62,7 @@ _ASCEND_CUSTOMOP_IS_REIGISTERED = False
_DEFAULT_BUFFER_SIZE = 200
_MIN_DP_BUFFER_SIZE = 50
_IS_MOE_MODEL = None
+_IS_DRAFTER_MOE_MODEL = None
_IS_VL_MODEL = None
_ENABLE_SP = None
_HAS_LAYER_IDX = None
@@ -842,6 +843,16 @@ def is_moe_model(vllm_config: VllmConfig):
    return _IS_MOE_MODEL


+def is_drafter_moe_model(vllm_config: VllmConfig):
+    """Checks if the drafter model is a MoE model by config"""
+    global _IS_DRAFTER_MOE_MODEL
+    if _IS_DRAFTER_MOE_MODEL is None:
+        model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \
+            .to_dict()
+        _IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
+    return _IS_DRAFTER_MOE_MODEL
+
+
def speculative_enable_dispatch_gmm_combine_decode(
        vllm_config: VllmConfig) -> bool:
    if vllm_config.speculative_config is None: