[Eagle3]enhance skipping dp allreduce and add it into eagle proposer (#6192)
### What this PR does / why we need it?
This PR:
1. Enhances the logic of `_skip_all_reduce_across_dp_group` to skip all
cpu dp allreduce for dense models. This is also for purpose 2.
2. Adds `_skip_all_reduce_across_dp_group` into eagle_proposer. Now
models like Qwen3-235b support eagle3 spec decode. A typical setting
for these moe models on pd disaggregation often introduces `dp_size > 1`.
This requires `set_forward_context` to call a cpu dp allreduce to
retrieve `num_tokens_across_dp` in all cases. Skipping this allreduce
greatly improves performance.
- vLLM version: v0.14.0
- vLLM main:
d68209402d
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -272,6 +272,7 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.runner.pcp_size = 1
|
||||
self.runner.dcp_size = 1
|
||||
self.runner.pin_memory = False
|
||||
self.runner._sync_metadata_across_dp.return_value = (8, torch.tensor([8]), False)
|
||||
|
||||
self.vllm_config.cache_config.block_size = 16
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
|
||||
|
||||
@@ -382,10 +382,17 @@ class EagleProposer(VllmEagleProposer):
|
||||
model_previous_hidden_states = self.hidden_states[:num_tokens]
|
||||
|
||||
batch_size = num_tokens // (self.num_speculative_tokens + 1)
|
||||
(
|
||||
num_tokens,
|
||||
num_tokens_across_dp,
|
||||
_,
|
||||
) = self.runner._sync_metadata_across_dp(num_tokens,
|
||||
is_draft_model=True)
|
||||
with set_ascend_forward_context(
|
||||
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
num_actual_tokens=0,
|
||||
in_profile_run=is_profile,
|
||||
batch_descriptor=batch_descriptor,
|
||||
@@ -531,10 +538,17 @@ class EagleProposer(VllmEagleProposer):
|
||||
self.last_token_indices[:last_token_indices_len].copy_(
|
||||
last_token_indices)
|
||||
|
||||
(
|
||||
num_input_tokens,
|
||||
num_tokens_across_dp,
|
||||
_,
|
||||
) = self.runner._sync_metadata_across_dp(num_input_tokens,
|
||||
is_draft_model=True)
|
||||
with set_ascend_forward_context(
|
||||
multi_steps_attn_metadata[0],
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
num_actual_tokens=num_tokens,
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
|
||||
@@ -103,7 +103,8 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
|
||||
enable_sp, get_ascend_device_type, is_moe_model,
|
||||
enable_sp, get_ascend_device_type,
|
||||
is_drafter_moe_model, is_moe_model,
|
||||
lmhead_tp_enable, maybe_trans_nz,
|
||||
set_weight_prefetch_method)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
@@ -393,17 +394,24 @@ class NPUModelRunner(GPUModelRunner):
|
||||
def _use_aclgraph(self) -> bool:
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
|
||||
|
||||
def _skip_all_reduce_across_dp_group(self) -> bool:
|
||||
def _skip_all_reduce_across_dp_group(self, is_draft_model=False) -> bool:
|
||||
"""
|
||||
Decide whether to skip the all-reduce across the data-parallel (DP) group.
|
||||
|
||||
Skipping is only applicable for MoE models and only on ranks that act as
|
||||
KV consumers. We skip the DP all-reduce when either:
|
||||
Skipping is applicable for all dense models and for moe models only on ranks
|
||||
that act as KV consumers. We skip the DP all-reduce when either:
|
||||
- Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
|
||||
- Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
|
||||
"""
|
||||
# Only applicable to MoE models and KV consumer ranks.
|
||||
if not is_moe_model(self.vllm_config) or not self.is_kv_consumer:
|
||||
# For dense models, since we don't actually need dp communication, we simply skip it.
|
||||
# This usually happens when main model is moe while eagle draft model is dense.
|
||||
is_context_moe_model = is_drafter_moe_model(self.vllm_config) if is_draft_model \
|
||||
else is_moe_model(self.vllm_config)
|
||||
if not is_context_moe_model:
|
||||
return True
|
||||
|
||||
# Only applicable to MoE models on KV consumer ranks.
|
||||
if not self.is_kv_consumer:
|
||||
return False
|
||||
|
||||
def needs_mc2(num_tokens: int) -> bool:
|
||||
@@ -431,8 +439,11 @@ class NPUModelRunner(GPUModelRunner):
|
||||
or self.ascend_config.recompute_scheduler_enable)
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
self,
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
is_draft_model: bool = False
|
||||
) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
|
||||
# our case, we still need to sync the other two flags as well. So we need to
|
||||
# include them in the all_reduce operation, and more over, we CANNOT skip it
|
||||
@@ -442,7 +453,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if self.dp_size == 1:
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
if self._skip_all_reduce_across_dp_group():
|
||||
if self._skip_all_reduce_across_dp_group(is_draft_model):
|
||||
num_tokens_after_padding = torch.tensor([num_tokens] *
|
||||
self.dp_size,
|
||||
device="cpu",
|
||||
|
||||
Reference in New Issue
Block a user