[Bugfix] Fixed an accuracy problem of sp with eagle3 (#5816)
### What this PR does / why we need it?
Fixed an accuracy problem when using eagle3 with sp.
The problem is described in
https://github.com/vllm-project/vllm-ascend/issues/5825.
It also adds a much more precise way to determine whether drafter should
use `sp` or not.
Also, it changes the `eager` of drafter to be a real `eager` in frontend
to avoid a `fx-graph` problem.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
For simpilicity, we test it as in
https://github.com/vllm-project/vllm-ascend/issues/5825.
And we get the same result of `eagle3` with `sp` disabled.
```text
--------------------------------------------------
total_num_output_tokens: 1000
num_drafts: 437
num_draft_tokens: 1311
num_accepted_tokens: 564
mean acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.62
acceptance at token 1: 0.40
acceptance at token 2: 0.27
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00
```
* vLLM version: v0.13.0
* vLLM main:
2f4e6548ef
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
2
.github/workflows/_e2e_test.yaml
vendored
2
.github/workflows/_e2e_test.yaml
vendored
@@ -217,7 +217,7 @@ jobs:
|
||||
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py
|
||||
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py
|
||||
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
|
||||
|
||||
pytest -sv --durations=0 tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py
|
||||
|
||||
# To avoid oom, we need to run the test in a single process.
|
||||
pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2
|
||||
|
||||
156
tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py
Normal file
156
tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Run `pytest tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.v1.metrics.reader import Counter, Vector
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
MODELS = {
|
||||
"eagle3": {
|
||||
"main": "Qwen/Qwen3-8B",
|
||||
"spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
|
||||
},
|
||||
}
|
||||
|
||||
# NOTE: golden may change (eagle_proposer only runs in eager mode currently),
|
||||
# thus please update it if ci fails but you have better acceptance
|
||||
BASELINES_SP = {
|
||||
"eagle3": [0.68, 0.40, 0.18],
|
||||
}
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||
@pytest.mark.parametrize("method", ["eagle3"])
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [3])
|
||||
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_eagle3_sp_acceptance(
|
||||
method: str,
|
||||
num_speculative_tokens: int,
|
||||
disable_padded_drafter_batch: bool,
|
||||
async_scheduling: bool,
|
||||
):
|
||||
if disable_padded_drafter_batch and async_scheduling:
|
||||
pytest.skip(
|
||||
"skip disable_padded_drafter_batch=True and async_scheduling=True",
|
||||
)
|
||||
|
||||
main_model_name = MODELS[method]["main"]
|
||||
spec_model_name = MODELS[method]["spec"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
main_model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
ignore_eos=False,
|
||||
max_tokens=256,
|
||||
)
|
||||
|
||||
# sp will only be enabled when query_lens > 1000
|
||||
prompts = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "Hello, my name is",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "The president of the United States is",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "The capital of France is",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "The future of AI is",
|
||||
},
|
||||
]
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(
|
||||
[prompt],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
) for prompt in prompts
|
||||
]
|
||||
|
||||
speculative_config = {
|
||||
"enforce_eager": True,
|
||||
"method": method,
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
"disable_padded_drafter_batch": disable_padded_drafter_batch,
|
||||
"model": spec_model_name,
|
||||
}
|
||||
|
||||
compilation_config = CompilationConfig(cudagraph_mode="FULL_DECODE_ONLY",
|
||||
cudagraph_capture_sizes=[12])
|
||||
|
||||
with VllmRunner(
|
||||
main_model_name,
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
disable_log_stats=False,
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=256,
|
||||
distributed_executor_backend="mp",
|
||||
gpu_memory_utilization=0.7,
|
||||
speculative_config=speculative_config,
|
||||
compilation_config=compilation_config,
|
||||
async_scheduling=async_scheduling,
|
||||
) as llm:
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
metrics = llm.model.get_metrics()
|
||||
|
||||
num_drafts = 0
|
||||
num_accepted_tokens_per_pos = [0] * num_speculative_tokens
|
||||
for metric in metrics:
|
||||
if metric.name == "vllm:spec_decode_num_drafts":
|
||||
assert isinstance(metric, Counter)
|
||||
num_drafts += metric.value
|
||||
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
|
||||
assert isinstance(metric, Vector)
|
||||
for pos in range(len(metric.values)):
|
||||
num_accepted_tokens_per_pos[pos] += metric.values[pos]
|
||||
|
||||
acceptance_per_pos = [
|
||||
num_accepted_tokens / num_drafts
|
||||
for num_accepted_tokens in num_accepted_tokens_per_pos
|
||||
]
|
||||
golden = BASELINES_SP[method]
|
||||
|
||||
match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
|
||||
if not match:
|
||||
print(f"acceptance_per_pos: {acceptance_per_pos}")
|
||||
print(f"golden: {golden}")
|
||||
|
||||
assert match
|
||||
@@ -34,11 +34,6 @@ BASELINES = {
|
||||
"eagle3": [0.68, 0.40, 0.18],
|
||||
}
|
||||
|
||||
BASELINES_SP = {
|
||||
"eagle3": [0.68, 0.40, 0.18],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_prompts():
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
@@ -381,111 +376,3 @@ def test_llama_qwen_eagle_acceptance(
|
||||
print(f"golden: {golden}")
|
||||
|
||||
assert match
|
||||
|
||||
|
||||
# TODO the function of sp in eagle3 is improving gradually,
|
||||
# there are still problems when enable sp + dp and some unknown scenes.
|
||||
# this e2e should also be improving gradually.
|
||||
@pytest.mark.parametrize("method", ["eagle3"])
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [3])
|
||||
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
|
||||
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||
def test_eagle3_sp_acceptance(
|
||||
method: str,
|
||||
num_speculative_tokens: int,
|
||||
disable_padded_drafter_batch: bool,
|
||||
async_scheduling: bool,
|
||||
):
|
||||
if disable_padded_drafter_batch and async_scheduling:
|
||||
pytest.skip(
|
||||
"skip disable_padded_drafter_batch=True and async_scheduling=True",
|
||||
)
|
||||
|
||||
main_model_name = MODELS[method]["main"]
|
||||
spec_model_name = MODELS[method]["spec"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
main_model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
ignore_eos=False,
|
||||
max_tokens=256,
|
||||
)
|
||||
|
||||
# sp will only be enabled when query_lens > 1000
|
||||
prompts = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "Hello, my name is",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "The president of the United States is",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "The capital of France is",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": " " * 1000 + "The future of AI is",
|
||||
},
|
||||
]
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(
|
||||
[prompt],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
) for prompt in prompts
|
||||
]
|
||||
|
||||
speculative_config = {
|
||||
"method": method,
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
"disable_padded_drafter_batch": disable_padded_drafter_batch,
|
||||
"model": spec_model_name,
|
||||
}
|
||||
|
||||
compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
|
||||
|
||||
with VllmRunner(
|
||||
main_model_name,
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
disable_log_stats=False,
|
||||
tensor_parallel_size=1,
|
||||
max_num_seqs=256,
|
||||
distributed_executor_backend="mp",
|
||||
gpu_memory_utilization=0.7,
|
||||
speculative_config=speculative_config,
|
||||
compilation_config=compilation_config,
|
||||
async_scheduling=async_scheduling,
|
||||
) as llm:
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
metrics = llm.model.get_metrics()
|
||||
|
||||
num_drafts = 0
|
||||
num_accepted_tokens_per_pos = [0] * num_speculative_tokens
|
||||
for metric in metrics:
|
||||
if metric.name == "vllm:spec_decode_num_drafts":
|
||||
assert isinstance(metric, Counter)
|
||||
num_drafts += metric.value
|
||||
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
|
||||
assert isinstance(metric, Vector)
|
||||
for pos in range(len(metric.values)):
|
||||
num_accepted_tokens_per_pos[pos] += metric.values[pos]
|
||||
|
||||
acceptance_per_pos = [
|
||||
num_accepted_tokens / num_drafts
|
||||
for num_accepted_tokens in num_accepted_tokens_per_pos
|
||||
]
|
||||
golden = BASELINES_SP[method]
|
||||
|
||||
match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
|
||||
if not match:
|
||||
print(f"acceptance_per_pos: {acceptance_per_pos}")
|
||||
print(f"golden: {golden}")
|
||||
|
||||
assert match
|
||||
|
||||
@@ -275,7 +275,9 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.mock_cpugpubuffer.stop()
|
||||
self.mock_supports_multimodal_inputs.stop()
|
||||
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
|
||||
**{"return_value.sp_enabled": False})
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_basic(self, mock_context, mock_get_context):
|
||||
num_tokens = 32
|
||||
@@ -288,7 +290,9 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
|
||||
self.assertTrue(self.proposer.model.call_count == 4)
|
||||
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
|
||||
**{"return_value.sp_enabled": False})
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
|
||||
mock_context.return_value.__enter__.return_value = None
|
||||
@@ -306,6 +310,8 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
mock_return_context = MagicMock()
|
||||
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
mock_return_context.capturing = True
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
mock_return_context.sp_enabled = False
|
||||
mock_get_context.return_value = mock_return_context
|
||||
self.proposer.use_cuda_graph = True
|
||||
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
|
||||
@@ -326,6 +332,8 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
mock_return_context = MagicMock()
|
||||
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
mock_return_context.capturing = False
|
||||
# cpu does not support parallel-group, let alone `sp`
|
||||
mock_return_context.sp_enabled = False
|
||||
mock_get_context.return_value = mock_return_context
|
||||
self.proposer.use_cuda_graph = True
|
||||
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
|
||||
|
||||
@@ -14,7 +14,7 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
|
||||
get_ascend_device_type, has_layer_idx,
|
||||
is_moe_model,
|
||||
is_drafter_moe_model, is_moe_model,
|
||||
speculative_enable_dispatch_gmm_combine_decode)
|
||||
|
||||
|
||||
@@ -73,7 +73,10 @@ def set_ascend_forward_context(
|
||||
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
||||
# the performance may degrade due to the switching of communication methods.
|
||||
mmrs_fusion = True
|
||||
if is_moe_model(vllm_config):
|
||||
# main model and drafter model may have different architecture
|
||||
is_context_moe_model = is_drafter_moe_model(vllm_config) \
|
||||
if is_draft_model else is_moe_model(vllm_config)
|
||||
if is_context_moe_model:
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
mmrs_fusion = False
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Any, ContextManager, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -8,7 +8,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
|
||||
from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
|
||||
get_world_group,
|
||||
init_model_parallel_group,
|
||||
patch_tensor_parallel_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -42,12 +43,45 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
||||
from vllm_ascend.ops.triton.spec_decode.utils import \
|
||||
prepare_inputs_padded_kernel
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.utils import shared_expert_dp_enabled
|
||||
from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled
|
||||
|
||||
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
||||
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
||||
|
||||
|
||||
# TODO: Remove it when the bug of fx-graph is solved
|
||||
# patch vllm_config to be in CompilationMode.NONE temporarily
|
||||
@contextmanager
|
||||
def _maybe_eager_context(vllm_config):
|
||||
raw_compilation_config_mode = vllm_config.compilation_config.mode
|
||||
vllm_config.compilation_config.mode = CompilationMode.NONE
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
vllm_config.compilation_config.mode = raw_compilation_config_mode
|
||||
|
||||
|
||||
# split hidden states along dimension of sequence
|
||||
def split_inputs_tp_to_sp(hidden_states, out):
|
||||
# tp and sp share the same group
|
||||
group = get_tp_group()
|
||||
|
||||
world_size = group.world_size
|
||||
rank = group.rank
|
||||
|
||||
num_tokens = hidden_states.shape[0]
|
||||
# the size per rank after padded
|
||||
padded_num_tokens_per_rank = (num_tokens + world_size - 1) // world_size
|
||||
# compute the start and end of slice
|
||||
start = padded_num_tokens_per_rank * rank
|
||||
end = padded_num_tokens_per_rank * (rank + 1)
|
||||
|
||||
# copy only hidden_states in current rank
|
||||
hidden_states_curr_rank = hidden_states[start:end]
|
||||
out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank
|
||||
return out[:padded_num_tokens_per_rank]
|
||||
|
||||
|
||||
class EagleProposer(VllmEagleProposer):
|
||||
|
||||
def __init__(self,
|
||||
@@ -118,6 +152,11 @@ class EagleProposer(VllmEagleProposer):
|
||||
else:
|
||||
self.tp_group_context = nullcontext()
|
||||
|
||||
# TODO: Remove it when the bug of fx-graph is solved
|
||||
self.maybe_eager_context: ContextManager[Any] = nullcontext()
|
||||
if not self.use_cuda_graph and enable_sp(vllm_config):
|
||||
self.maybe_eager_context = _maybe_eager_context(vllm_config)
|
||||
|
||||
def load_model(self, model: nn.Module) -> None:
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
@@ -126,9 +165,10 @@ class EagleProposer(VllmEagleProposer):
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
DeepseekV32IndexerCache).keys())
|
||||
|
||||
self.model = get_model(vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.
|
||||
speculative_config.draft_model_config)
|
||||
with self.maybe_eager_context:
|
||||
self.model = get_model(vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.
|
||||
speculative_config.draft_model_config)
|
||||
|
||||
indexer_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache).keys()
|
||||
@@ -273,8 +313,10 @@ class EagleProposer(VllmEagleProposer):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.sp_enabled:
|
||||
model_previous_hidden_states = split_inputs_tp_to_sp(
|
||||
model_previous_hidden_states,
|
||||
model_previous_hidden_states)
|
||||
|
||||
self.model(
|
||||
@@ -282,7 +324,6 @@ class EagleProposer(VllmEagleProposer):
|
||||
positions=model_positions,
|
||||
hidden_states=model_previous_hidden_states,
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if (forward_context.cudagraph_runtime_mode
|
||||
== CUDAGraphMode.FULL
|
||||
and not forward_context.capturing):
|
||||
@@ -293,7 +334,7 @@ class EagleProposer(VllmEagleProposer):
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
if forward_context.sp_enabled:
|
||||
model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
model_previous_hidden_states, True)
|
||||
|
||||
@@ -383,19 +424,19 @@ class EagleProposer(VllmEagleProposer):
|
||||
model_positions = self.positions[:num_input_tokens]
|
||||
model_hidden_states = self.hidden_states[:num_input_tokens]
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.sp_enabled:
|
||||
# split hidden states along sequence dimension
|
||||
# positions should not be split?
|
||||
model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
model_hidden_states)
|
||||
# in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
|
||||
model_hidden_states = split_inputs_tp_to_sp(
|
||||
model_hidden_states, model_hidden_states)
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_hidden_states,
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
# TODO: support mla in future.
|
||||
update_attn_params(
|
||||
@@ -405,7 +446,7 @@ class EagleProposer(VllmEagleProposer):
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
if forward_context.sp_enabled:
|
||||
# merge hidden states along sequence dimension
|
||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
last_hidden_states.contiguous(), True)
|
||||
@@ -536,19 +577,18 @@ class EagleProposer(VllmEagleProposer):
|
||||
model_positions = self.positions[:input_batch_size]
|
||||
model_hidden_states = self.hidden_states[:input_batch_size]
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.sp_enabled:
|
||||
# split hidden states along sequence dimension
|
||||
# positions should not be split?
|
||||
model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
model_hidden_states)
|
||||
# in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
|
||||
model_hidden_states = split_inputs_tp_to_sp(
|
||||
model_hidden_states, model_hidden_states)
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_hidden_states,
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
update_attn_params(
|
||||
self.update_stream,
|
||||
@@ -557,7 +597,7 @@ class EagleProposer(VllmEagleProposer):
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
if forward_context.sp_enabled:
|
||||
# merge hidden states along sequence dimension
|
||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
last_hidden_states.contiguous(), True)
|
||||
|
||||
@@ -62,6 +62,7 @@ _ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
||||
_DEFAULT_BUFFER_SIZE = 200
|
||||
_MIN_DP_BUFFER_SIZE = 50
|
||||
_IS_MOE_MODEL = None
|
||||
_IS_DRAFTER_MOE_MODEL = None
|
||||
_IS_VL_MODEL = None
|
||||
_ENABLE_SP = None
|
||||
_HAS_LAYER_IDX = None
|
||||
@@ -842,6 +843,16 @@ def is_moe_model(vllm_config: VllmConfig):
|
||||
return _IS_MOE_MODEL
|
||||
|
||||
|
||||
def is_drafter_moe_model(vllm_config: VllmConfig):
|
||||
"""Checks if the drafter model is a MoE model by config"""
|
||||
global _IS_DRAFTER_MOE_MODEL
|
||||
if _IS_DRAFTER_MOE_MODEL is None:
|
||||
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \
|
||||
.to_dict()
|
||||
_IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
|
||||
return _IS_DRAFTER_MOE_MODEL
|
||||
|
||||
|
||||
def speculative_enable_dispatch_gmm_combine_decode(
|
||||
vllm_config: VllmConfig) -> bool:
|
||||
if vllm_config.speculative_config is None:
|
||||
|
||||
Reference in New Issue
Block a user