[releases/v0.18.0][BugFix] Restore global_bs=0 and mc2_mask for uniform-token dispatching and support inter-node RoCE hierarchical MC2 communication (#8040)
### What this PR does / why we need it?
Cherry-picked from #8039
Restore the `global_bs=0` setting and the `mc2_mask` handling for MC2 when the
`all_reduce` across the DP group cannot be skipped. The Ascend MC2 ops require
`global_bs=0` together with an explicit `mc2_mask` when inter-node RoCE hierarchical
communication is enabled. PR #4983 always passed a non-zero `global_bs` without
`mc2_mask`, which is incompatible with the hierarchical communication introduced in PR #7583.
**Changes:**
- Add `should_skip_allreduce_across_dp_group()` to `utils.py` with the
hierarchical-communication constraint
- Set `global_bs=0` when the all-reduce is not skipped, and pass `mc2_mask`
accordingly
- Add an `mc2_mask` field to `MoEMC2CombineMetadata` so the mask chosen at
dispatch time is propagated to the combine op (see the sketch below)
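For illustration, the intended behaviour can be sketched in a few lines of Python: when the DP all-reduce cannot be skipped, the MC2 dispatch receives `global_bs=0` plus an explicit `mc2_mask`, and the mask is carried in the combine metadata. The helper name `build_mc2_dispatch_args`, the stand-in `MoEMC2CombineMetadata` definition, and the non-zero `global_bs` formula are hypothetical; this is a sketch, not the vllm-ascend implementation.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoEMC2CombineMetadata:
    # mc2_mask is produced at dispatch time and propagated to the combine op.
    mc2_mask: Optional[torch.Tensor] = None


def build_mc2_dispatch_args(
    num_valid_tokens: int,
    num_padded_tokens: int,
    ep_world_size: int,
    skip_dp_allreduce: bool,
) -> tuple[dict, MoEMC2CombineMetadata]:
    """Pick global_bs / mc2_mask for MC2 dispatch and remember the mask for combine."""
    if skip_dp_allreduce:
        # Uniform token counts across DP ranks: a concrete global batch size can be
        # handed to the kernel and no per-token validity mask is needed.
        args = {"global_bs": num_padded_tokens * ep_world_size, "mc2_mask": None}
        return args, MoEMC2CombineMetadata(mc2_mask=None)

    # Token counts may differ across DP ranks (padding): hierarchical MC2
    # communication needs global_bs=0 plus a mask marking the real tokens.
    mc2_mask = torch.zeros(num_padded_tokens, dtype=torch.bool)
    mc2_mask[:num_valid_tokens] = True
    args = {"global_bs": 0, "mc2_mask": mc2_mask}
    return args, MoEMC2CombineMetadata(mc2_mask=mc2_mask)
```

Keeping the mask in the combine metadata is what the third change above refers to: the mask chosen at dispatch time has to reach the combine op unchanged.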
### Does this PR introduce _any_ user-facing change?
No. However, this PR fixes cross-super-node communication on A3 when
`enable_mc2_hierarchy_comm=True` is set in `additional_config` and
`export HCCL_INTRA_ROCE_ENABLE=1` is set in the environment.
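As a rough usage sketch of that configuration (the model path is a placeholder, and `enable_mc2_hierarchy_comm` is read from `additional_config` as the description above implies):

```python
import os

# Same effect as `export HCCL_INTRA_ROCE_ENABLE=1`; on multi-node deployments the
# variable has to be present in every node's environment before the workers start.
os.environ["HCCL_INTRA_ROCE_ENABLE"] = "1"

from vllm import LLM

llm = LLM(
    model="/path/to/moe-model",  # placeholder model path
    additional_config={"enable_mc2_hierarchy_comm": True},
)
```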
### How was this patch tested?
E2E serving succeeded and CI passed.
- vLLM version: v0.18.0
- vLLM main: 14acf429ac
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
```diff
@@ -123,10 +123,9 @@ from vllm_ascend.utils import (
     enable_sp,
     enable_sp_by_pass,
     global_stream,
-    is_drafter_moe_model,
-    is_moe_model,
     lmhead_tp_enable,
     set_weight_prefetch_method,
+    should_skip_allreduce_across_dp_group,
 )
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 from vllm_ascend.worker.pcp_utils import PCPManager
@@ -469,46 +468,6 @@ class NPUModelRunner(GPUModelRunner):
             and not self.model_config.enforce_eager
         )
 
-    def _skip_all_reduce_across_dp_group(self, is_draft_model=False) -> bool:
-        """
-        Decide whether to skip the all-reduce across the data-parallel (DP) group.
-
-        Skipping is applicable for all dense models and for moe models only on ranks
-        that act as KV consumers. We skip the DP all-reduce when either:
-        - Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
-        - Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
-        """
-        # For dense models, since we don't actually need dp communication, we simply skip it.
-        # This usually happens when main model is moe while eagle draft model is dense.
-        is_context_moe_model = (
-            is_drafter_moe_model(self.vllm_config) if is_draft_model else is_moe_model(self.vllm_config)
-        )
-        if not is_context_moe_model:
-            return True
-
-        # Only applicable to MoE models on KV consumer ranks.
-        if not self.is_kv_consumer:
-            return False
-
-        def needs_mc2(num_tokens: int) -> bool:
-            return select_moe_comm_method(num_tokens, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
-
-        # Determine whether decode must use MC2. Use max cudagraph capture size
-        # if available, otherwise use the maximal uniform decode token count.
-        if self.compilation_config.cudagraph_capture_sizes:
-            potential_max_tokens = self.compilation_config.max_cudagraph_capture_size
-        else:
-            potential_max_tokens = self.max_num_reqs * self.uniform_decode_query_len
-        decode_must_use_mc2 = needs_mc2(potential_max_tokens)
-
-        # For prefill, use the scheduler's max_num_batched_tokens for a single
-        # batch.
-        prefill_must_use_mc2 = needs_mc2(self.vllm_config.scheduler_config.max_num_batched_tokens)
-
-        # Skip all-reduce if decode requires MC2 and either prefill also
-        # requires MC2 or recompute-based scheduler is enabled.
-        return decode_must_use_mc2 and (prefill_must_use_mc2 or self.ascend_config.recompute_scheduler_enable)
-
     def _sync_metadata_across_dp(
         self, num_tokens: int, with_prefill: bool = False, is_draft_model: bool = False
     ) -> tuple[int, torch.Tensor | None, bool]:
@@ -521,7 +480,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.dp_size == 1:
             return num_tokens, None, with_prefill
 
-        if self._skip_all_reduce_across_dp_group(is_draft_model):
+        if should_skip_allreduce_across_dp_group(self.vllm_config, is_draft_model):
             num_tokens_after_padding = torch.tensor([num_tokens] * self.dp_size, device="cpu", dtype=torch.int32)
             return num_tokens, num_tokens_after_padding, with_prefill
 
@@ -1891,7 +1850,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.dp_size == 1:
             return False, None, cudagraph_mode
 
-        if self._skip_all_reduce_across_dp_group():
+        if should_skip_allreduce_across_dp_group(self.vllm_config):
             num_tokens_after_padding = torch.tensor([num_tokens_padded] * self.dp_size, device="cpu", dtype=torch.int32)
             return False, num_tokens_after_padding, cudagraph_mode
 
```