[Bugfix] Fixed an accuracy problem of sp with eagle3 (#5816)
### What this PR does / why we need it?
Fixed an accuracy problem when using eagle3 with sp. The problem is
described in
https://github.com/vllm-project/vllm-ascend/issues/5825.
It also adds a more precise check for deciding whether the drafter
should use `sp`.
Finally, it makes the drafter's `eager` mode a true frontend `eager`
mode to avoid an `fx-graph` problem.
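For context, the core of the accuracy fix is that under `sp` the drafter's hidden states must be split along the sequence dimension, with each rank keeping one padded, contiguous chunk. Below is a minimal, self-contained sketch of that slicing; it mirrors the `split_inputs_tp_to_sp` helper added in this PR, with `world_size` and `rank` standing in for the values the real code reads from `get_tp_group()` (tp and sp share the same group):

```python
import torch


def split_tokens_for_rank(hidden_states: torch.Tensor, world_size: int,
                          rank: int) -> torch.Tensor:
    """Keep only this rank's padded, contiguous slice of the sequence."""
    num_tokens = hidden_states.shape[0]
    # round the per-rank slot size up so world_size slots cover all tokens
    padded_num_tokens_per_rank = (num_tokens + world_size - 1) // world_size
    start = padded_num_tokens_per_rank * rank
    end = padded_num_tokens_per_rank * (rank + 1)
    return hidden_states[start:end]


# e.g. 10 tokens on 4 ranks: ranks 0-2 keep 3 tokens each, rank 3 keeps 1
x = torch.randn(10, 8)
print([split_tokens_for_rank(x, 4, r).shape[0] for r in range(4)])  # [3, 3, 3, 1]
```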
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
For simplicity, we tested it as described in
https://github.com/vllm-project/vllm-ascend/issues/5825,
and obtained the same results as running `eagle3` with `sp` disabled:
```text
--------------------------------------------------
total_num_output_tokens: 1000
num_drafts: 437
num_draft_tokens: 1311
num_accepted_tokens: 564
mean acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.62
acceptance at token 1: 0.40
acceptance at token 2: 0.27
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00
```
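As a sanity check, these numbers are self-consistent: mean acceptance length = 1 + num_accepted_tokens / num_drafts = 1 + 564 / 437 ≈ 2.29, and the per-position acceptance rates sum to the same ratio (0.62 + 0.40 + 0.27 ≈ 1.29).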
* vLLM version: v0.13.0
* vLLM main: 2f4e6548ef
Signed-off-by: drslark <slarksblood@qq.com>
.github/workflows/_e2e_test.yaml
```diff
@@ -217,7 +217,7 @@ jobs:
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
-
+          pytest -sv --durations=0 tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py
 
           # To avoid oom, we need to run the test in a single process.
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2
```
tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py (new file, 156 lines)
```diff
@@ -0,0 +1,156 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+# Run `pytest tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py`.
+
+from __future__ import annotations
+
+import math
+import os
+import random
+from typing import Any, Union
+from unittest.mock import patch
+
+import pytest
+from transformers import AutoTokenizer
+from vllm import LLM, SamplingParams
+from vllm.config import CompilationConfig
+from vllm.v1.metrics.reader import Counter, Vector
+
+from tests.e2e.conftest import VllmRunner
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+MODELS = {
+    "eagle3": {
+        "main": "Qwen/Qwen3-8B",
+        "spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
+    },
+}
+
+# NOTE: golden may change (eagle_proposer only runs in eager mode currently),
+# thus please update it if ci fails but you have better acceptance
+BASELINES_SP = {
+    "eagle3": [0.68, 0.40, 0.18],
+}
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+@pytest.mark.parametrize("method", ["eagle3"])
+@pytest.mark.parametrize("num_speculative_tokens", [3])
+@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
+@pytest.mark.parametrize("async_scheduling", [True, False])
+def test_eagle3_sp_acceptance(
+    method: str,
+    num_speculative_tokens: int,
+    disable_padded_drafter_batch: bool,
+    async_scheduling: bool,
+):
+    if disable_padded_drafter_batch and async_scheduling:
+        pytest.skip(
+            "skip disable_padded_drafter_batch=True and async_scheduling=True",
+        )
+
+    main_model_name = MODELS[method]["main"]
+    spec_model_name = MODELS[method]["spec"]
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        main_model_name,
+        trust_remote_code=True,
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        ignore_eos=False,
+        max_tokens=256,
+    )
+
+    # sp will only be enabled when query_lens > 1000
+    prompts = [
+        {
+            "role": "user",
+            "content": " " * 1000 + "Hello, my name is",
+        },
+        {
+            "role": "user",
+            "content": " " * 1000 + "The president of the United States is",
+        },
+        {
+            "role": "user",
+            "content": " " * 1000 + "The capital of France is",
+        },
+        {
+            "role": "user",
+            "content": " " * 1000 + "The future of AI is",
+        },
+    ]
+    prompts = [
+        tokenizer.apply_chat_template(
+            [prompt],
+            tokenize=False,
+            add_generation_prompt=True,
+        ) for prompt in prompts
+    ]
+
+    speculative_config = {
+        "enforce_eager": True,
+        "method": method,
+        "num_speculative_tokens": num_speculative_tokens,
+        "disable_padded_drafter_batch": disable_padded_drafter_batch,
+        "model": spec_model_name,
+    }
+
+    compilation_config = CompilationConfig(cudagraph_mode="FULL_DECODE_ONLY",
+                                           cudagraph_capture_sizes=[12])
+
+    with VllmRunner(
+            main_model_name,
+            enforce_eager=True,
+            max_model_len=8192,
+            disable_log_stats=False,
+            tensor_parallel_size=2,
+            max_num_seqs=256,
+            distributed_executor_backend="mp",
+            gpu_memory_utilization=0.7,
+            speculative_config=speculative_config,
+            compilation_config=compilation_config,
+            async_scheduling=async_scheduling,
+    ) as llm:
+        _ = llm.generate(prompts, sampling_params)
+        metrics = llm.model.get_metrics()
+
+    num_drafts = 0
+    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_drafts":
+            assert isinstance(metric, Counter)
+            num_drafts += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+            assert isinstance(metric, Vector)
+            for pos in range(len(metric.values)):
+                num_accepted_tokens_per_pos[pos] += metric.values[pos]
+
+    acceptance_per_pos = [
+        num_accepted_tokens / num_drafts
+        for num_accepted_tokens in num_accepted_tokens_per_pos
+    ]
+    golden = BASELINES_SP[method]
+
+    match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
+    if not match:
+        print(f"acceptance_per_pos: {acceptance_per_pos}")
+        print(f"golden: {golden}")
+
+    assert match
```
```diff
@@ -34,11 +34,6 @@ BASELINES = {
     "eagle3": [0.68, 0.40, 0.18],
 }
 
-BASELINES_SP = {
-    "eagle3": [0.68, 0.40, 0.18],
-}
-
-
 @pytest.fixture
 def test_prompts():
     prompt_types = ["repeat", "sentence"]
@@ -381,111 +376,3 @@ def test_llama_qwen_eagle_acceptance(
         print(f"golden: {golden}")
 
     assert match
-
-
-# TODO the function of sp in eagle3 is improving gradually,
-# there are still problems when enable sp + dp and some unknown scenes.
-# this e2e should also be improving gradually.
-@pytest.mark.parametrize("method", ["eagle3"])
-@pytest.mark.parametrize("num_speculative_tokens", [3])
-@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
-@pytest.mark.parametrize("async_scheduling", [True, False])
-def test_eagle3_sp_acceptance(
-    method: str,
-    num_speculative_tokens: int,
-    disable_padded_drafter_batch: bool,
-    async_scheduling: bool,
-):
-    if disable_padded_drafter_batch and async_scheduling:
-        pytest.skip(
-            "skip disable_padded_drafter_batch=True and async_scheduling=True",
-        )
-
-    main_model_name = MODELS[method]["main"]
-    spec_model_name = MODELS[method]["spec"]
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        main_model_name,
-        trust_remote_code=True,
-    )
-    sampling_params = SamplingParams(
-        temperature=0,
-        ignore_eos=False,
-        max_tokens=256,
-    )
-
-    # sp will only be enabled when query_lens > 1000
-    prompts = [
-        {
-            "role": "user",
-            "content": " " * 1000 + "Hello, my name is",
-        },
-        {
-            "role": "user",
-            "content": " " * 1000 + "The president of the United States is",
-        },
-        {
-            "role": "user",
-            "content": " " * 1000 + "The capital of France is",
-        },
-        {
-            "role": "user",
-            "content": " " * 1000 + "The future of AI is",
-        },
-    ]
-    prompts = [
-        tokenizer.apply_chat_template(
-            [prompt],
-            tokenize=False,
-            add_generation_prompt=True,
-        ) for prompt in prompts
-    ]
-
-    speculative_config = {
-        "method": method,
-        "num_speculative_tokens": num_speculative_tokens,
-        "disable_padded_drafter_batch": disable_padded_drafter_batch,
-        "model": spec_model_name,
-    }
-
-    compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
-
-    with VllmRunner(
-            main_model_name,
-            enforce_eager=True,
-            max_model_len=8192,
-            disable_log_stats=False,
-            tensor_parallel_size=1,
-            max_num_seqs=256,
-            distributed_executor_backend="mp",
-            gpu_memory_utilization=0.7,
-            speculative_config=speculative_config,
-            compilation_config=compilation_config,
-            async_scheduling=async_scheduling,
-    ) as llm:
-        _ = llm.generate(prompts, sampling_params)
-        metrics = llm.model.get_metrics()
-
-    num_drafts = 0
-    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
-    for metric in metrics:
-        if metric.name == "vllm:spec_decode_num_drafts":
-            assert isinstance(metric, Counter)
-            num_drafts += metric.value
-        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
-            assert isinstance(metric, Vector)
-            for pos in range(len(metric.values)):
-                num_accepted_tokens_per_pos[pos] += metric.values[pos]
-
-    acceptance_per_pos = [
-        num_accepted_tokens / num_drafts
-        for num_accepted_tokens in num_accepted_tokens_per_pos
-    ]
-    golden = BASELINES_SP[method]
-
-    match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
-    if not match:
-        print(f"acceptance_per_pos: {acceptance_per_pos}")
-        print(f"golden: {golden}")
-
-    assert match
```
```diff
@@ -275,7 +275,9 @@ class TestEagleProposerDummyRun(TestBase):
         self.mock_cpugpubuffer.stop()
         self.mock_supports_multimodal_inputs.stop()
 
-    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
+    # cpu does not support parallel-group, let alone `sp`
+    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
+           **{"return_value.sp_enabled": False})
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_basic(self, mock_context, mock_get_context):
         num_tokens = 32
@@ -288,7 +290,9 @@ class TestEagleProposerDummyRun(TestBase):
 
         self.assertTrue(self.proposer.model.call_count == 4)
 
-    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
+    # cpu does not support parallel-group, let alone `sp`
+    @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
+           **{"return_value.sp_enabled": False})
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
         mock_context.return_value.__enter__.return_value = None
@@ -306,6 +310,8 @@ class TestEagleProposerDummyRun(TestBase):
         mock_return_context = MagicMock()
         mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
         mock_return_context.capturing = True
+        # cpu does not support parallel-group, let alone `sp`
+        mock_return_context.sp_enabled = False
         mock_get_context.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
@@ -326,6 +332,8 @@ class TestEagleProposerDummyRun(TestBase):
         mock_return_context = MagicMock()
         mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
         mock_return_context.capturing = False
+        # cpu does not support parallel-group, let alone `sp`
+        mock_return_context.sp_enabled = False
         mock_get_context.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
```
```diff
@@ -14,7 +14,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
                                get_ascend_device_type, has_layer_idx,
-                               is_moe_model,
+                               is_drafter_moe_model, is_moe_model,
                                speculative_enable_dispatch_gmm_combine_decode)
 
@@ -73,7 +73,10 @@ def set_ascend_forward_context(
     # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
     mmrs_fusion = True
-    if is_moe_model(vllm_config):
+    # main model and drafter model may have different architecture
+    is_context_moe_model = is_drafter_moe_model(vllm_config) \
+        if is_draft_model else is_moe_model(vllm_config)
+    if is_context_moe_model:
         sp_enabled = enable_sp(vllm_config) and num_tokens is not None
         mmrs_fusion = False
     else:
```
```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from contextlib import nullcontext
-from typing import Optional
+from contextlib import contextmanager, nullcontext
+from typing import Any, ContextManager, Optional
 
 import numpy as np
 import torch
@@ -8,7 +8,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
-from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
+from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group,
                                              patch_tensor_parallel_group)
 from vllm.forward_context import get_forward_context
@@ -42,12 +43,45 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
 from vllm_ascend.ops.triton.spec_decode.utils import \
     prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled
 
 # Currently we will fix block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
 
 
+# TODO: Remove it when the bug of fx-graph is solved
+# patch vllm_config to be in CompilationMode.NONE temporarily
+@contextmanager
+def _maybe_eager_context(vllm_config):
+    raw_compilation_config_mode = vllm_config.compilation_config.mode
+    vllm_config.compilation_config.mode = CompilationMode.NONE
+    try:
+        yield
+    finally:
+        vllm_config.compilation_config.mode = raw_compilation_config_mode
+
+
+# split hidden states along dimension of sequence
+def split_inputs_tp_to_sp(hidden_states, out):
+    # tp and sp share the same group
+    group = get_tp_group()
+
+    world_size = group.world_size
+    rank = group.rank
+
+    num_tokens = hidden_states.shape[0]
+    # the size per rank after padded
+    padded_num_tokens_per_rank = (num_tokens + world_size - 1) // world_size
+    # compute the start and end of slice
+    start = padded_num_tokens_per_rank * rank
+    end = padded_num_tokens_per_rank * (rank + 1)
+
+    # copy only hidden_states in current rank
+    hidden_states_curr_rank = hidden_states[start:end]
+    out[:hidden_states_curr_rank.shape[0]] = hidden_states_curr_rank
+    return out[:padded_num_tokens_per_rank]
+
+
 class EagleProposer(VllmEagleProposer):
 
     def __init__(self,
@@ -118,6 +152,11 @@ class EagleProposer(VllmEagleProposer):
         else:
             self.tp_group_context = nullcontext()
 
+        # TODO: Remove it when the bug of fx-graph is solved
+        self.maybe_eager_context: ContextManager[Any] = nullcontext()
+        if not self.use_cuda_graph and enable_sp(vllm_config):
+            self.maybe_eager_context = _maybe_eager_context(vllm_config)
+
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
@@ -126,6 +165,7 @@ class EagleProposer(VllmEagleProposer):
             get_layers_from_vllm_config(self.vllm_config,
                                         DeepseekV32IndexerCache).keys())
 
+        with self.maybe_eager_context:
             self.model = get_model(vllm_config=self.vllm_config,
                                    model_config=self.vllm_config.
                                    speculative_config.draft_model_config)
@@ -273,8 +313,10 @@ class EagleProposer(VllmEagleProposer):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):
 
-            if self.enable_shared_expert_dp:
-                model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                model_previous_hidden_states = split_inputs_tp_to_sp(
+                    model_previous_hidden_states,
                     model_previous_hidden_states)
 
             self.model(
@@ -282,7 +324,6 @@ class EagleProposer(VllmEagleProposer):
                 positions=model_positions,
                 hidden_states=model_previous_hidden_states,
             )
-            forward_context = get_forward_context()
             if (forward_context.cudagraph_runtime_mode
                     == CUDAGraphMode.FULL
                     and not forward_context.capturing):
@@ -293,7 +334,7 @@ class EagleProposer(VllmEagleProposer):
                     self.vllm_config,
                 )
 
-            if self.enable_shared_expert_dp:
+            if forward_context.sp_enabled:
                 model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                     model_previous_hidden_states, True)
 
@@ -383,19 +424,19 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self.positions[:num_input_tokens]
         model_hidden_states = self.hidden_states[:num_input_tokens]
 
-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
             # split hidden states along sequence dimension
             # positions should not be split?
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)
+            # in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
 
         last_hidden_states, hidden_states = self.model(
             input_ids=model_input_ids,
            positions=model_positions,
             hidden_states=model_hidden_states,
         )
-        forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             # TODO: support mla in future.
             update_attn_params(
@@ -405,7 +446,7 @@ class EagleProposer(VllmEagleProposer):
             self.vllm_config,
         )
 
-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
             # merge hidden states along sequence dimension
             last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 last_hidden_states.contiguous(), True)
@@ -536,19 +577,18 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self.positions[:input_batch_size]
         model_hidden_states = self.hidden_states[:input_batch_size]
 
-        if self.enable_shared_expert_dp:
+        forward_context = get_forward_context()
+        if forward_context.sp_enabled:
             # split hidden states along sequence dimension
             # positions should not be split?
-            model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                model_hidden_states)
+            model_hidden_states = split_inputs_tp_to_sp(
+                model_hidden_states, model_hidden_states)
+            # in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
 
         last_hidden_states, hidden_states = self.model(
             input_ids=model_input_ids,
             positions=model_positions,
             hidden_states=model_hidden_states,
         )
-        forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             update_attn_params(
                 self.update_stream,
@@ -557,7 +597,7 @@ class EagleProposer(VllmEagleProposer):
             self.vllm_config,
         )
 
-        if self.enable_shared_expert_dp:
+        if forward_context.sp_enabled:
             # merge hidden states along sequence dimension
             last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 last_hidden_states.contiguous(), True)
```
```diff
@@ -62,6 +62,7 @@ _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
 _IS_MOE_MODEL = None
+_IS_DRAFTER_MOE_MODEL = None
 _IS_VL_MODEL = None
 _ENABLE_SP = None
 _HAS_LAYER_IDX = None
@@ -842,6 +843,16 @@ def is_moe_model(vllm_config: VllmConfig):
     return _IS_MOE_MODEL
 
 
+def is_drafter_moe_model(vllm_config: VllmConfig):
+    """Checks if the drafter model is a MoE model by config"""
+    global _IS_DRAFTER_MOE_MODEL
+    if _IS_DRAFTER_MOE_MODEL is None:
+        model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \
+            .to_dict()
+        _IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
+    return _IS_DRAFTER_MOE_MODEL
+
+
 def speculative_enable_dispatch_gmm_combine_decode(
         vllm_config: VllmConfig) -> bool:
     if vllm_config.speculative_config is None:
```