[releases/v0.18.0][BugFix] Restore global_bs=0 and mc2_mask for uniform-token dispatching and support inter-node roce hierarchical MC2 communication (#8040)

### What this PR does / why we need it? 
Cherry-picked from #8039 
Restore the setting of MC2 `global_bs` and `mc2_mask` handling when
`all_reduce` across DP group cannot be skipped. Ascend MC2 ops require
`global_bs=0` + `mc2_mask` while enabling inter-node roce hierarchical
communication. PR #4983 always passed non-zero `global_bs` without
`mc2_mask`, which is incompatible with hierarchy comm raised in PR #7583
  **Changes:**                           
- Add `should_skip_allreduce_across_dp_group()` to `utils.py` with
hierarchy constraint
- Set `global_bs=0` when allreduce is not skipped; pass `mc2_mask`
accordingly
- Add `mc2_mask` field to `MoEMC2CombineMetadata` for dispatch→combine
propagation
### Does this PR introduce _any_ user-facing change?
No. But this PR fixes cross-super-node communication function on A3 with
`enable_mc2_hierarchy_comm=True` in `additional_config` and `export
HCCL_INTRA_ROCE_ENABLE=1`.

### How was this patch tested?
E2E serving succeeded and CI pssed.

- vLLM version: v0.18.0
- vLLM main:
14acf429ac

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-04-09 16:51:17 +08:00
committed by GitHub
parent 82e17f693a
commit 7c9aa498d6
4 changed files with 78 additions and 46 deletions

View File

@@ -92,6 +92,7 @@ class MoEMC2CombineMetadata:
assist_info_for_combine: torch.Tensor
expand_scales: torch.Tensor | None
dispatch_with_quant: bool
mc2_mask: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)

View File

@@ -40,7 +40,12 @@ from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoETokenDispatchOutput,
TMoECombineMetadata,
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
from vllm_ascend.utils import (
AscendDeviceType,
get_ascend_device_type,
is_hierarchical_communication_enabled,
should_skip_allreduce_across_dp_group,
)
class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
@@ -115,7 +120,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
else:
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
_max_global_bs = num_tokens_per_tp_rank * self.ep_world_size
# When allreduce across DP is not skipped, tokens are uniform across ranks:
# use global_bs=0 (uniform mode) and pass mc2_mask.
# When allreduce is skipped, tokens may differ per rank:
# use the real global_bs and do NOT pass mc2_mask.
self.global_bs = _max_global_bs if should_skip_allreduce_across_dp_group(vllm_config) else 0
# NOTE: When enable_mc2_hierarchy_comm is true, we need pass in `comm_alg` to mc2 op.
self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm
@@ -156,6 +167,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
"global_bs": self.global_bs,
"expert_token_nums_type": 0,
}
if self.global_bs == 0:
kwargs_mc2["x_active_mask"] = token_dispatch_input.routing.mc2_mask
stage1_kwargs = {
"scales": None,
@@ -228,6 +241,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
assist_info_for_combine=assist_info_for_combine,
expand_scales=expand_scales,
dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
mc2_mask=token_dispatch_input.routing.mc2_mask if self.global_bs == 0 else None,
),
)
@@ -251,6 +265,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
"moe_expert_num": self.moe_expert_num,
"global_bs": self.global_bs,
}
if self.global_bs == 0:
kwargs_mc2["x_active_mask"] = combine_metadata.mc2_mask
if combine_metadata.dispatch_with_quant:
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)