[BugFix] Make setting tp=1 for the Eagle draft model take effect (#5519)

### What this PR does / why we need it?
According to the official documentation, the parameter
`"draft_tensor_parallel_size": 1` is supposed to apply to the Eagle3
draft model. In practice, however, debugging showed that the Eagle
draft model always used the same tensor parallel size (tp) as the
target model, so setting tp for the draft model did not take effect.

**Note:** This feature has not yet been tested in combination with `sp`
and `dp`; those combinations will be adapted later.
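
To make the setting take effect, the change gives the drafter its own single-rank TP group and patches it in while the draft model is loaded (see the diff below). The following is only a minimal sketch of that idea, built on the vLLM distributed helpers the patch itself uses (`init_model_parallel_group`, `get_world_group`, `patch_tensor_parallel_group`); the wrapper name `make_drafter_tp_context` is hypothetical and exists only for illustration:

```python
from contextlib import nullcontext

import torch
from vllm.distributed.parallel_state import (get_world_group,
                                             init_model_parallel_group,
                                             patch_tensor_parallel_group)


def make_drafter_tp_context(target_tp: int, draft_tp: int):
    """Return a context manager under which the drafter sees tp=draft_tp.

    vLLM's `_verify_and_get_draft_tp` already constrains draft_tp to be
    either 1 or equal to target_tp, so the only special case handled here
    is a single-rank TP group for the draft model.
    """
    if target_tp != draft_tp:
        # Build a TP group that contains only the current rank (tp=1) and
        # return a context manager that patches it in as the active group.
        tp_group = init_model_parallel_group(
            [[get_world_group().rank]],
            get_world_group().rank,
            torch.distributed.get_backend(get_world_group().device_group),
            use_message_queue_broadcaster=True,
            group_name="tp",
        )
        return patch_tensor_parallel_group(tp_group)
    return nullcontext()
```

In the actual change, `EagleProposer` stores this context as `self.tp_group_context` and the NPU model runner wraps `drafter.load_model(...)` in it, so the Eagle weights are loaded for tp=1 while the target model keeps its original TP group.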
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```python
from vllm import LLM, SamplingParams

def main():
    prompts = [
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "draft_tensor_parallel_size": 1,
            "num_speculative_tokens": 3,
        },
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__ == "__main__":
    main()
```
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1

Fixes vllm-project/vllm#31345

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
Co-authored-by: drslark <slarksblood@qq.com>
zhaomingyu13 authored on 2026-01-13 09:14:30 +08:00, committed by GitHub
parent 7af3b880c1 · commit d886b81971
6 changed files with 65 additions and 12 deletions


@@ -23,6 +23,7 @@
from __future__ import annotations
import os
from typing import Union
import pytest
from vllm import SamplingParams
@@ -123,11 +124,11 @@ def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int,
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
num_speculative_tokens: int,
method: str,
disable_padded_drafter_batch: bool,
async_scheduling: bool):
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
def test_llama_qwen3_eagle_correctness(
model_name: str, model_name_main: str, num_speculative_tokens: int,
method: str, disable_padded_drafter_batch: bool,
async_scheduling: bool, draft_tensor_parallel_size: Union[None, int]):
example_prompts = [
"Hello, my name is",
@@ -162,6 +163,8 @@ def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
"method": method,
"model": model_name,
"num_speculative_tokens": num_speculative_tokens,
"draft_tensor_parallel_size":
draft_tensor_parallel_size,
"max_model_len": 128,
"draft_vocab_size": 128256,
},


@@ -4,7 +4,7 @@ from __future__ import annotations
import math
import os
import random
from typing import Any
from typing import Any, Union
import pytest
from transformers import AutoTokenizer
@@ -222,9 +222,11 @@ def test_suffix_acceptance(
@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
def test_eagle_logprobs(
model_name: str,
use_eagle3: bool,
draft_tensor_parallel_size: Union[None, int],
):
prompt = {"role": "user", "content": "Hello world " * 10}
sampling_params = SamplingParams(temperature=0,
@@ -251,6 +253,7 @@ def test_eagle_logprobs(
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"max_model_len": 128,
},
max_model_len=128,
@@ -276,11 +279,13 @@ def test_eagle_logprobs(
@pytest.mark.parametrize("method", MODELS.keys())
@pytest.mark.parametrize("num_speculative_tokens", [3])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_llama_qwen_eagle_acceptance(
method: str,
num_speculative_tokens: int,
draft_tensor_parallel_size: Union[None, int],
disable_padded_drafter_batch: bool,
async_scheduling: bool,
):
@@ -331,6 +336,7 @@ def test_llama_qwen_eagle_acceptance(
speculative_config = {
"method": method,
"num_speculative_tokens": num_speculative_tokens,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
"model": spec_model_name,
}


@@ -27,6 +27,8 @@ class TestEagleProposerInitialization(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.parallel_config.tensor_parallel_size = 1
self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
@@ -114,6 +116,8 @@ class TestEagleProposerLoadModel(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.parallel_config.tensor_parallel_size = 1
self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
@@ -246,6 +250,8 @@ class TestEagleProposerDummyRun(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.parallel_config.tensor_parallel_size = 1
self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(4)
])
@@ -352,6 +358,8 @@ class TestEagleProposerHelperMethods(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
self.vllm_config.parallel_config.tensor_parallel_size = 1
self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)


@@ -42,6 +42,9 @@ class TestMtpProposer:
config.model_config.max_model_len = 2048
config.model_config.uses_mrope = False
config.model_config.hf_text_config = None
config.model_config.hf_config = None
config.parallel_config.tensor_parallel_size = 1
config.speculative_config.draft_tensor_parallel_size = 1
config.load_config = None


@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import nullcontext
from typing import Optional
import numpy as np
@@ -7,7 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.parallel_state import (get_pp_group, get_world_group,
init_model_parallel_group,
patch_tensor_parallel_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -93,6 +96,27 @@ class EagleProposer(VllmEagleProposer):
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
"index_topk")
# NOTE:
# `draft_tensor_parallel_size` does not take effect for Eagle by default:
# in practice the draft model uses the same TP size as the target model,
# so this patch sets tp=1 for the draft model separately.
# Because of the `_verify_and_get_draft_tp` check in vLLM,
# `draft_tensor_parallel_size` here is either 1 or equal to the
# target model's TP size.
# TODO(zhaomingyu13): If we need to support a draft model TP that is
# neither 1 nor equal to the target model's, this part must be rewritten.
if (vllm_config.parallel_config.tensor_parallel_size
!= self.speculative_config.draft_tensor_parallel_size):
tp_group = init_model_parallel_group(
[[get_world_group().rank]],
get_world_group().rank,
torch.distributed.get_backend(get_world_group().device_group),
use_message_queue_broadcaster=True,
group_name="tp",
)
self.tp_group_context = patch_tensor_parallel_group(tp_group)
else:
self.tp_group_context = nullcontext()
def load_model(self, model: nn.Module) -> None:
target_attn_layer_names = set(


@@ -165,6 +165,10 @@ def graph_capture(device: torch.device):
yield graph_capture_context
def get_tp_context(drafter):
return getattr(drafter, "tp_group_context", nullcontext())
class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
@@ -2320,7 +2324,8 @@ class NPUModelRunner(GPUModelRunner):
model_register(self.model, self.model_config)
if self.drafter:
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
with get_tp_context(self.drafter):
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
@@ -2696,11 +2701,15 @@ class NPUModelRunner(GPUModelRunner):
kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups):
if isinstance(kv_cache_group.kv_cache_spec,
EncoderOnlyAttentionSpec):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# Pick an arbitrary one to dispatch.
kv_cache_spec = next(
iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
elif isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from
# the backend.