[releases/v0.18.0][BugFix] Restore global_bs=0 and mc2_mask for uniform-token dispatching and support inter-node roce hierarchical MC2 communication (#8040)
### What this PR does / why we need it?
Cherry-picked from #8039
Restore the setting of MC2 `global_bs` and `mc2_mask` handling when
`all_reduce` across DP group cannot be skipped. Ascend MC2 ops require
`global_bs=0` + `mc2_mask` while enabling inter-node roce hierarchical
communication. PR #4983 always passed non-zero `global_bs` without
`mc2_mask`, which is incompatible with hierarchy comm raised in PR #7583
**Changes:**
- Add `should_skip_allreduce_across_dp_group()` to `utils.py` with
hierarchy constraint
- Set `global_bs=0` when allreduce is not skipped; pass `mc2_mask`
accordingly
- Add `mc2_mask` field to `MoEMC2CombineMetadata` for dispatch→combine
propagation
### Does this PR introduce _any_ user-facing change?
No. But this PR fixes cross-super-node communication function on A3 with
`enable_mc2_hierarchy_comm=True` in `additional_config` and `export
HCCL_INTRA_ROCE_ENABLE=1`.
### How was this patch tested?
E2E serving succeeded and CI pssed.
- vLLM version: v0.18.0
- vLLM main:
14acf429ac
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -92,6 +92,7 @@ class MoEMC2CombineMetadata:
|
||||
assist_info_for_combine: torch.Tensor
|
||||
expand_scales: torch.Tensor | None
|
||||
dispatch_with_quant: bool
|
||||
mc2_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
|
||||
@@ -40,7 +40,12 @@ from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||
MoETokenDispatchOutput,
|
||||
TMoECombineMetadata,
|
||||
)
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
|
||||
from vllm_ascend.utils import (
|
||||
AscendDeviceType,
|
||||
get_ascend_device_type,
|
||||
is_hierarchical_communication_enabled,
|
||||
should_skip_allreduce_across_dp_group,
|
||||
)
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
|
||||
@@ -115,7 +120,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
||||
else:
|
||||
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
|
||||
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
||||
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
||||
_max_global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
||||
|
||||
# When allreduce across DP is not skipped, tokens are uniform across ranks:
|
||||
# use global_bs=0 (uniform mode) and pass mc2_mask.
|
||||
# When allreduce is skipped, tokens may differ per rank:
|
||||
# use the real global_bs and do NOT pass mc2_mask.
|
||||
self.global_bs = _max_global_bs if should_skip_allreduce_across_dp_group(vllm_config) else 0
|
||||
|
||||
# NOTE: When enable_mc2_hierarchy_comm is true, we need pass in `comm_alg` to mc2 op.
|
||||
self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm
|
||||
@@ -156,6 +167,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
||||
"global_bs": self.global_bs,
|
||||
"expert_token_nums_type": 0,
|
||||
}
|
||||
if self.global_bs == 0:
|
||||
kwargs_mc2["x_active_mask"] = token_dispatch_input.routing.mc2_mask
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
@@ -228,6 +241,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
||||
assist_info_for_combine=assist_info_for_combine,
|
||||
expand_scales=expand_scales,
|
||||
dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
|
||||
mc2_mask=token_dispatch_input.routing.mc2_mask if self.global_bs == 0 else None,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -251,6 +265,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
|
||||
"moe_expert_num": self.moe_expert_num,
|
||||
"global_bs": self.global_bs,
|
||||
}
|
||||
if self.global_bs == 0:
|
||||
kwargs_mc2["x_active_mask"] = combine_metadata.mc2_mask
|
||||
|
||||
if combine_metadata.dispatch_with_quant:
|
||||
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
Reference in New Issue
Block a user