From d886b81971f85efcda053e95daa58619f9b1769b Mon Sep 17 00:00:00 2001
From: zhaomingyu13
Date: Tue, 13 Jan 2026 09:14:30 +0800
Subject: [PATCH] [BugFix] Make setting tp=1 for the Eagle draft model take
 effect (#5519)

### What this PR does / why we need it?
According to the official documentation, setting `"draft_tensor_parallel_size": 1` is supposed to apply to the Eagle3 draft model. However, actual debugging showed that the tensor parallel size (tp) of the Eagle model always matched that of the target model: the tp setting for the draft model did not take effect as expected.

**Note:** This feature has not yet been tested in combination with `sp` and `dp`; those combinations will be adapted later.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "draft_tensor_parallel_size": 1,
            "num_speculative_tokens": 3,
        },
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/45c1ca1ca1ee8fa06df263c8715e8a412ff408d4

Fixes vllm-project/vllm#31345

Signed-off-by: zhaomingyu
Co-authored-by: drslark
---
 .../spec_decode/test_mtp_eagle_correctness.py | 13 ++++++----
 .../spec_decode/test_v1_spec_decode.py        |  8 +++++-
 tests/ut/spec_decode/test_eagle_proposer.py   |  8 ++++++
 tests/ut/spec_decode/test_mtp_proposer.py     |  3 +++
 vllm_ascend/spec_decode/eagle_proposer.py     | 26 ++++++++++++++++++-
 vllm_ascend/worker/model_runner_v1.py         | 19 ++++++++++----
 6 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py b/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
index 421a0e88..d1393971 100644
--- a/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
+++ b/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
@@ -23,6 +23,7 @@
 from __future__ import annotations
 
 import os
+from typing import Union
 
 import pytest
 from vllm import SamplingParams
@@ -123,11 +124,11 @@ def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int,
 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
 @pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
 @pytest.mark.parametrize("async_scheduling", [True, False])
-def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
-                                       num_speculative_tokens: int,
-                                       method: str,
-                                       disable_padded_drafter_batch: bool,
-                                       async_scheduling: bool):
+@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
+def test_llama_qwen3_eagle_correctness(
+        model_name: str, model_name_main: str, num_speculative_tokens: int,
+        method: str, disable_padded_drafter_batch: bool,
+        async_scheduling: bool, draft_tensor_parallel_size: Union[None, int]):
 
     example_prompts = [
         "Hello, my name is",
@@ -162,6 +163,8 @@ def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
             "method":
             method,
             "model": model_name,
             "num_speculative_tokens": num_speculative_tokens,
+            "draft_tensor_parallel_size":
+            draft_tensor_parallel_size,
             "max_model_len": 128,
             "draft_vocab_size": 128256,
         },
diff --git a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
index 5cca89d5..27a86aa2 100644
--- a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import math
 import os
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
 from transformers import AutoTokenizer
@@ -222,9 +222,11 @@ def test_suffix_acceptance(
 
 
 @pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
+@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
 def test_eagle_logprobs(
     model_name: str,
     use_eagle3: bool,
+    draft_tensor_parallel_size: Union[None, int],
 ):
     prompt = {"role": "user", "content": "Hello world " * 10}
     sampling_params = SamplingParams(temperature=0,
@@ -251,6 +253,7 @@ def test_eagle_logprobs(
             "method": "eagle3" if use_eagle3 else "eagle",
             "model": spec_model_name,
             "num_speculative_tokens": 2,
+            "draft_tensor_parallel_size": draft_tensor_parallel_size,
             "max_model_len": 128,
         },
         max_model_len=128,
@@ -276,11 +279,13 @@ def test_eagle_logprobs(
 
 @pytest.mark.parametrize("method", MODELS.keys())
 @pytest.mark.parametrize("num_speculative_tokens", [3])
+@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
 @pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
 @pytest.mark.parametrize("async_scheduling", [True, False])
 def test_llama_qwen_eagle_acceptance(
     method: str,
     num_speculative_tokens: int,
+    draft_tensor_parallel_size: Union[None, int],
     disable_padded_drafter_batch: bool,
     async_scheduling: bool,
 ):
@@ -331,6 +336,7 @@ def test_llama_qwen_eagle_acceptance(
     speculative_config = {
         "method": method,
         "num_speculative_tokens": num_speculative_tokens,
+        "draft_tensor_parallel_size": draft_tensor_parallel_size,
         "disable_padded_drafter_batch": disable_padded_drafter_batch,
         "model": spec_model_name,
     }
diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 3e30ecc8..310167f8 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -27,6 +27,8 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.num_speculative_tokens = 2
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(2)
@@ -114,6 +116,8 @@ class TestEagleProposerLoadModel(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.num_speculative_tokens = 2
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(2)
@@ -246,6 +250,8 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(4)
         ])
@@ -352,6 +358,8 @@ class TestEagleProposerHelperMethods(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.num_speculative_tokens = 2
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(2)
diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py
index 7c69c12c..906c0c65 100644
--- a/tests/ut/spec_decode/test_mtp_proposer.py
+++ b/tests/ut/spec_decode/test_mtp_proposer.py
@@ -42,6 +42,9 @@ class TestMtpProposer:
         config.model_config.max_model_len = 2048
         config.model_config.uses_mrope = False
         config.model_config.hf_text_config = None
+        config.model_config.hf_config = None
+        config.parallel_config.tensor_parallel_size = 1
+        config.speculative_config.draft_tensor_parallel_size = 1
 
         config.load_config = None
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 72b6f1a1..d106a054 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from contextlib import nullcontext
 from typing import Optional
 
 import numpy as np
@@ -7,7 +8,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
+                                             init_model_parallel_group,
+                                             patch_tensor_parallel_group)
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -93,6 +96,27 @@ class EagleProposer(VllmEagleProposer):
         self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
                                   "index_topk")
 
+        # NOTE: `draft_tensor_parallel_size` does not take effect for
+        # Eagle out of the box: in practice the draft model ends up with
+        # the same TP size as the target model. This patch therefore sets
+        # tp=1 for the draft model separately. Because of the
+        # `_verify_and_get_draft_tp` check in vLLM, the value of
+        # `draft_tensor_parallel_size` here is either 1 or equal to the
+        # target model's TP size.
+        # TODO(zhaomingyu13): to support a draft-model TP that is neither
+        # 1 nor equal to the target model's, this part must be rewritten.
+        if (vllm_config.parallel_config.tensor_parallel_size
+                != self.speculative_config.draft_tensor_parallel_size):
+            tp_group = init_model_parallel_group(
+                [[get_world_group().rank]],
+                get_world_group().rank,
+                torch.distributed.get_backend(get_world_group().device_group),
+                use_message_queue_broadcaster=True,
+                group_name="tp",
+            )
+            self.tp_group_context = patch_tensor_parallel_group(tp_group)
+        else:
+            self.tp_group_context = nullcontext()
 
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 757da788..ff0ff8f5 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -165,6 +165,10 @@ def graph_capture(device: torch.device):
     yield graph_capture_context
 
 
+def get_tp_context(drafter):
+    return getattr(drafter, "tp_group_context", nullcontext())
+
+
 class ExecuteModelState(NamedTuple):
     """Ephemeral cached state transferred between execute_model()
     and sample_tokens(), after execute_model() returns None."""
@@ -2320,7 +2324,8 @@ class NPUModelRunner(GPUModelRunner):
         model_register(self.model, self.model_config)
         if self.drafter:
             logger.info("Loading drafter model...")
-            self.drafter.load_model(self.model)
+            with get_tp_context(self.drafter):
+                self.drafter.load_model(self.model)
         if self.use_aux_hidden_state_outputs:
             self.model.set_aux_hidden_state_layers(
                 self.model.get_eagle3_aux_hidden_state_layers())
@@ -2696,11 +2701,15 @@ class NPUModelRunner(GPUModelRunner):
         kernel_block_sizes = []
         for kv_cache_group_id, kv_cache_group in enumerate(
                 kv_cache_config.kv_cache_groups):
-
-            if isinstance(kv_cache_group.kv_cache_spec,
-                          EncoderOnlyAttentionSpec):
+            kv_cache_spec = kv_cache_group.kv_cache_spec
+            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
+                # All layers in a UniformTypeKVCacheSpecs have the same type,
+                # so pick an arbitrary one to dispatch on.
+                kv_cache_spec = next(
+                    iter(kv_cache_spec.kv_cache_specs.values()))
+            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
                 continue
-            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
+            elif isinstance(kv_cache_spec, AttentionSpec):
                 # This is an attention backend that supports virtual
                 # block splitting. Get the supported block sizes from
                 # the backend.
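
---

For reviewers: the TP-patching pattern is split across two files above (`eagle_proposer.py` builds `tp_group_context`; `model_runner_v1.py` enters it via `get_tp_context` around `drafter.load_model`). Below is a minimal self-contained sketch of that pattern, assuming vLLM's distributed state is already initialized. The distributed helpers are the real ones imported in the diff; `load_draft_model` is a hypothetical stand-in for `EagleProposer.load_model`.

```python
# Minimal sketch (not the actual PR code): build a single-rank TP group
# and patch it while constructing the draft model, so its tensor-parallel
# layers see tp=1 while the target model keeps its original TP size.
from contextlib import nullcontext

import torch
from vllm.distributed.parallel_state import (get_world_group,
                                             init_model_parallel_group,
                                             patch_tensor_parallel_group)


def make_draft_tp_context(target_tp: int, draft_tp: int):
    if target_tp == draft_tp:
        # Same TP size for draft and target: no patching needed.
        return nullcontext()
    # Build a TP group containing only the current rank; any
    # tensor-parallel layer created inside the context sees tp=1.
    tp_group = init_model_parallel_group(
        [[get_world_group().rank]],
        get_world_group().rank,
        torch.distributed.get_backend(get_world_group().device_group),
        use_message_queue_broadcaster=True,
        group_name="tp",
    )
    return patch_tensor_parallel_group(tp_group)


def load_draft_model():
    """Hypothetical stand-in for EagleProposer.load_model()."""


# Only the draft model is constructed under the patched tp=1 group;
# code outside the `with` block sees the original TP group again.
with make_draft_tp_context(target_tp=4, draft_tp=1):
    load_draft_model()
```

Because `patch_tensor_parallel_group` is a context manager that restores the previous TP group on exit, the same `tp_group_context` can be re-entered for draft-model load and (later) draft-model execution without affecting the target model.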