[BugFix] Make setting tp=1 for the Eagle draft model take effect (#6097)
### What this PR does / why we need it?
According to the official documentation, setting `"draft_tensor_parallel_size": 1` in `speculative_config` is supposed to apply to the Eagle3 draft model. In practice, however, debugging showed that the Eagle draft model always used the same tensor-parallel (TP) size as the target model: the draft-model TP setting never took effect.
**Note:** This feature has not been tested in combination with `sp` or `dp`; those combinations will be adapted later.
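For context, vLLM validates the requested draft TP size before it ever reaches the drafter. Below is a toy restatement of that constraint (the real check is vLLM's `_verify_and_get_draft_tp`, referenced in the diff further down; this stand-in function is illustrative only):

```python
from typing import Union


def verify_draft_tp(target_tp: int, draft_tp: Union[int, None]) -> int:
    """Toy stand-in for vLLM's _verify_and_get_draft_tp constraint."""
    if draft_tp is None:
        return target_tp  # default: follow the target model's TP size
    if draft_tp not in (1, target_tp):
        raise ValueError(f"draft_tensor_parallel_size must be 1 or "
                         f"{target_tp}, got {draft_tp}")
    return draft_tp


assert verify_draft_tp(4, None) == 4  # unset: draft follows the target
assert verify_draft_tp(4, 1) == 1    # the case this PR makes effective
```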
### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "draft_tensor_parallel_size": 1,
            "num_speculative_tokens": 3,
        },
    )
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```
Fixes vllm-project/vllm#31345
- vLLM version: v0.13.0
- vLLM main: d68209402d
Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
Co-authored-by: drslark <slarksblood@qq.com>
The end-to-end Eagle correctness test gains a `draft_tensor_parallel_size` parametrization:

```diff
@@ -23,6 +23,7 @@
 from __future__ import annotations
 
 import os
+from typing import Union
 
 import pytest
 from vllm import SamplingParams
@@ -124,11 +125,11 @@ def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int,
 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
 @pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
 @pytest.mark.parametrize("async_scheduling", [True, False])
-def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
-                                       num_speculative_tokens: int,
-                                       method: str,
-                                       disable_padded_drafter_batch: bool,
-                                       async_scheduling: bool):
+@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
+def test_llama_qwen3_eagle_correctness(
+        model_name: str, model_name_main: str, num_speculative_tokens: int,
+        method: str, disable_padded_drafter_batch: bool,
+        async_scheduling: bool, draft_tensor_parallel_size: Union[None, int]):
 
     example_prompts = [
         "Hello, my name is",
@@ -163,6 +164,8 @@ def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
             "method": method,
             "model": model_name,
             "num_speculative_tokens": num_speculative_tokens,
+            "draft_tensor_parallel_size":
+            draft_tensor_parallel_size,
             "max_model_len": 128,
             "draft_vocab_size": 128256,
         },
```
The logprobs and acceptance tests gain the same parametrization:

```diff
@@ -4,7 +4,7 @@ from __future__ import annotations
 import math
 import os
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
 from transformers import AutoTokenizer
@@ -267,9 +267,11 @@ def test_suffix_acceptance(
 
 
 @pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
+@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
 def test_eagle_logprobs(
     model_name: str,
     use_eagle3: bool,
+    draft_tensor_parallel_size: Union[None, int],
 ):
     prompt = {"role": "user", "content": "Hello world " * 10}
     sampling_params = SamplingParams(temperature=0,
@@ -296,6 +298,7 @@ def test_eagle_logprobs(
             "method": "eagle3" if use_eagle3 else "eagle",
             "model": spec_model_name,
             "num_speculative_tokens": 2,
+            "draft_tensor_parallel_size": draft_tensor_parallel_size,
             "max_model_len": 128,
         },
         max_model_len=128,
@@ -321,11 +324,13 @@ def test_eagle_logprobs(
 
 @pytest.mark.parametrize("method", MODELS.keys())
 @pytest.mark.parametrize("num_speculative_tokens", [3])
+@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
 @pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
 @pytest.mark.parametrize("async_scheduling", [True, False])
 def test_llama_qwen_eagle_acceptance(
     method: str,
     num_speculative_tokens: int,
+    draft_tensor_parallel_size: Union[None, int],
     disable_padded_drafter_batch: bool,
     async_scheduling: bool,
 ):
@@ -376,6 +381,7 @@ def test_llama_qwen_eagle_acceptance(
     speculative_config = {
         "method": method,
         "num_speculative_tokens": num_speculative_tokens,
        "draft_tensor_parallel_size": draft_tensor_parallel_size,
         "disable_padded_drafter_batch": disable_padded_drafter_batch,
         "model": spec_model_name,
     }
```
The EagleProposer unit-test fixtures now set both TP sizes explicitly:

```diff
@@ -27,6 +27,8 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.num_speculative_tokens = 2
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(2)
@@ -115,6 +117,8 @@ class TestEagleProposerLoadModel(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.num_speculative_tokens = 2
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(2)
@@ -256,6 +260,8 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
         self.vllm_config.model_config.use_mla = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(4)
         ])
@@ -370,6 +376,8 @@ class TestEagleProposerHelperMethods(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.parallel_config.tensor_parallel_size = 1
+        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
         self.vllm_config.speculative_config.num_speculative_tokens = 2
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(2)
```
The MTP proposer test fixture gets the same explicit settings:

```diff
@@ -42,6 +42,9 @@ class TestMtpProposer:
         config.model_config.max_model_len = 2048
         config.model_config.uses_mrope = False
         config.model_config.hf_text_config = None
+        config.model_config.hf_config = None
+        config.parallel_config.tensor_parallel_size = 1
+        config.speculative_config.draft_tensor_parallel_size = 1
 
         config.load_config = None
 
```
The core change in `EagleProposer`:

```diff
@@ -115,6 +115,27 @@ class EagleProposer(VllmEagleProposer):
 
         self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
                                   "index_topk")
+        # NOTE:
+        # `draft_tensor_parallel_size` does not take effect for Eagle:
+        # in practice the draft model uses the same TP size as the target
+        # model, so this patch sets tp=1 for the draft model separately.
+        # Because of the `_verify_and_get_draft_tp` check in vLLM, the value
+        # of `draft_tensor_parallel_size` here is either 1 or the same as
+        # the target model's TP size.
+        # TODO(zhaomingyu13): If we want to support a draft-model TP that is
+        # neither 1 nor the target model's, this part must be rewritten.
+        if (vllm_config.parallel_config.tensor_parallel_size
+                != self.speculative_config.draft_tensor_parallel_size):
+            tp_group = init_model_parallel_group(
+                [[get_world_group().rank]],
+                get_world_group().rank,
+                torch.distributed.get_backend(get_world_group().device_group),
+                use_message_queue_broadcaster=True,
+                group_name="tp",
+            )
+            self.tp_group_context = patch_tensor_parallel_group(tp_group)
+        else:
+            self.tp_group_context = nullcontext()
 
         self.use_cuda_graph = (self.runner._use_aclgraph()
                                and not self.speculative_config.enforce_eager
```
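The hunk above is the heart of the fix: when the configured draft TP size differs from the target's, the proposer builds a single-rank TP group and keeps a context manager that patches it in around draft-model loading. Below is a minimal sketch of the same pattern factored into a standalone helper; the distributed calls mirror the hunk above, but the helper name `make_draft_tp_context` and the import path are assumptions, not part of this patch:

```python
from contextlib import nullcontext
from typing import Union

import torch.distributed

# Import path is an assumption; the patch itself does not show its imports.
from vllm.distributed.parallel_state import (get_world_group,
                                             init_model_parallel_group,
                                             patch_tensor_parallel_group)


def make_draft_tp_context(target_tp: int, draft_tp: Union[int, None]):
    """Hypothetical helper: a context manager that swaps in a tp=1 group.

    With one single-rank group per rank, every rank loads the full,
    unsharded draft weights instead of a 1/target_tp shard.
    """
    if draft_tp is not None and draft_tp != target_tp:
        tp_group = init_model_parallel_group(
            [[get_world_group().rank]],  # one single-rank group -> tp=1
            get_world_group().rank,
            torch.distributed.get_backend(get_world_group().device_group),
            use_message_queue_broadcaster=True,
            group_name="tp",
        )
        return patch_tensor_parallel_group(tp_group)
    return nullcontext()  # draft follows the target's TP: nothing to patch
```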
A small helper in the model runner returns that context manager when present:

```diff
@@ -170,6 +170,10 @@ def graph_capture(device: torch.device):
     yield graph_capture_context
 
 
+def get_tp_context(drafter):
+    return getattr(drafter, "tp_group_context", nullcontext())
+
+
 class ExecuteModelState(NamedTuple):
     """Ephemeral cached state transferred between execute_model() and
     sample_tokens(), after execute_model() returns None."""
```
Draft-model loading is then wrapped in the (possibly patched) TP context:

```diff
@@ -2339,7 +2343,8 @@ class NPUModelRunner(GPUModelRunner):
             model_register(self.model, self.model_config)
         if self.drafter:
             logger.info("Loading drafter model...")
-            self.drafter.load_model(self.model)
+            with get_tp_context(self.drafter):
+                self.drafter.load_model(self.model)
             if self.use_aux_hidden_state_outputs:
                 self.model.set_aux_hidden_state_layers(
                     self.model.get_eagle3_aux_hidden_state_layers())
```
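A note on the `getattr(..., nullcontext())` fallback in `get_tp_context`: drafters that never set `tp_group_context` remain safe to wrap, so `load_model` behaves exactly as before for them. A self-contained toy example (the `_LegacyDrafter` class is purely illustrative):

```python
from contextlib import nullcontext


class _LegacyDrafter:
    """Illustrative stand-in for a drafter without tp_group_context."""

    def load_model(self):
        print("loading with the target model's TP group")


def get_tp_context(drafter):
    # Fall back to a no-op context for drafters without the attribute.
    return getattr(drafter, "tp_group_context", nullcontext())


drafter = _LegacyDrafter()
with get_tp_context(drafter):  # nullcontext(): behaves as before the patch
    drafter.load_model()
```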
The KV-cache group dispatch is also adjusted:

```diff
@@ -2715,11 +2720,15 @@ class NPUModelRunner(GPUModelRunner):
         kernel_block_sizes = []
         for kv_cache_group_id, kv_cache_group in enumerate(
                 kv_cache_config.kv_cache_groups):
-            if isinstance(kv_cache_group.kv_cache_spec,
-                          EncoderOnlyAttentionSpec):
+            kv_cache_spec = kv_cache_group.kv_cache_spec
+            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
+                # All layers in the UniformTypeKVCacheSpecs have the same
+                # type; pick an arbitrary one to dispatch on.
+                kv_cache_spec = next(
+                    iter(kv_cache_spec.kv_cache_specs.values()))
+            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
                 continue
-            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
+            elif isinstance(kv_cache_spec, AttentionSpec):
                 # This is an attention backend that supports virtual
                 # block splitting. Get the supported block sizes from
                 # the backend.
```
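In this last hunk, a `UniformTypeKVCacheSpecs` group is first unwrapped to one representative member spec, and the `isinstance` checks then run against that spec. A toy sketch of the unwrapping with minimal stand-in classes (the real vLLM spec classes carry many more fields):

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class AttentionSpec:
    block_size: int


@dataclass
class EncoderOnlyAttentionSpec(AttentionSpec):
    pass


@dataclass
class UniformTypeKVCacheSpecs:
    kv_cache_specs: Dict[str, AttentionSpec]  # layer -> spec, all same type


def resolve_spec(spec):
    # Unwrap a uniform group to any one member; all share the same type.
    if isinstance(spec, UniformTypeKVCacheSpecs):
        spec = next(iter(spec.kv_cache_specs.values()))
    return spec


group = UniformTypeKVCacheSpecs({"layers.0": AttentionSpec(block_size=128)})
spec = resolve_spec(group)
assert isinstance(spec, AttentionSpec) and spec.block_size == 128
```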