[releases/v0.18.0][BugFix] Restore global_bs=0 and mc2_mask for uniform-token dispatching and support inter-node roce hierarchical MC2 communication (#8040)

### What this PR does / why we need it? 
Cherry-picked from #8039 
Restore the setting of MC2 `global_bs` and `mc2_mask` handling when
`all_reduce` across DP group cannot be skipped. Ascend MC2 ops require
`global_bs=0` + `mc2_mask` while enabling inter-node roce hierarchical
communication. PR #4983 always passed non-zero `global_bs` without
`mc2_mask`, which is incompatible with hierarchy comm raised in PR #7583
  **Changes:**                           
- Add `should_skip_allreduce_across_dp_group()` to `utils.py` with
hierarchy constraint
- Set `global_bs=0` when allreduce is not skipped; pass `mc2_mask`
accordingly
- Add `mc2_mask` field to `MoEMC2CombineMetadata` for dispatch→combine
propagation
### Does this PR introduce _any_ user-facing change?
No. But this PR fixes cross-super-node communication function on A3 with
`enable_mc2_hierarchy_comm=True` in `additional_config` and `export
HCCL_INTRA_ROCE_ENABLE=1`.

### How was this patch tested?
E2E serving succeeded and CI pssed.

- vLLM version: v0.18.0
- vLLM main:
14acf429ac

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-04-09 16:51:17 +08:00
committed by GitHub
parent 82e17f693a
commit 7c9aa498d6
4 changed files with 78 additions and 46 deletions

View File

@@ -997,6 +997,62 @@ def is_hierarchical_communication_enabled():
) or get_ascend_config().enable_mc2_hierarchy_comm
def should_skip_allreduce_across_dp_group(vllm_config, is_draft_model: bool = False) -> bool:
"""Decide whether to skip the all-reduce across the DP group.
Skipping is applicable for all dense models and for moe models only on ranks
that act as KV consumers. We skip the DP all-reduce when either:
- Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
- Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
Skipping means each rank may have a different number of tokens, so MC2 needs
a non-zero global_bs and must NOT receive mc2_mask.
Returns False when hierarchy comm is enabled because hierarchy requires
global_bs=0 (uniform tokens), which is incompatible with skipping allreduce.
"""
if is_hierarchical_communication_enabled():
return False
# For dense models, since we don't actually need dp communication, we simply skip it.
# This usually happens when main model is moe while eagle draft model is dense.
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
if not is_context_moe_model:
return True
# Only applicable to MoE models on KV consumer ranks.
is_kv_consumer = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer
if not is_kv_consumer:
return False
from vllm_ascend.ascend_forward_context import select_moe_comm_method
from vllm_ascend.ops.fused_moe.moe_comm_method import MoECommType
def needs_mc2(n: int) -> bool:
return select_moe_comm_method(n, vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
compilation_config = vllm_config.compilation_config
scheduler_config = vllm_config.scheduler_config
speculative_config = vllm_config.speculative_config
uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens
decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
# Determine whether decode must use MC2. Use max cudagraph capture size
# if available, otherwise use the maximal uniform decode token count.
if compilation_config.cudagraph_capture_sizes:
potential_max_tokens = compilation_config.max_cudagraph_capture_size
else:
potential_max_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
decode_must_use_mc2 = needs_mc2(potential_max_tokens)
# For prefill, use the scheduler's max_num_batched_tokens for a single batch.
prefill_must_use_mc2 = needs_mc2(scheduler_config.max_num_batched_tokens)
# Skip all-reduce if decode requires MC2 and either prefill also
# requires MC2 or recompute-based scheduler is enabled.
return decode_must_use_mc2 and (prefill_must_use_mc2 or get_ascend_config().recompute_scheduler_enable)
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
if model_instance is None:
return False