[releases/v0.18.0][BugFix] Restore global_bs=0 and mc2_mask for uniform-token dispatching and support inter-node RoCE hierarchical MC2 communication (#8040)
### What this PR does / why we need it?
Cherry-picked from #8039
Restore the `global_bs=0` setting and the `mc2_mask` handling for MC2 when the
`all_reduce` across the DP group cannot be skipped. Ascend MC2 ops require
`global_bs=0` together with `mc2_mask` when inter-node RoCE hierarchical
communication is enabled. PR #4983 always passed a non-zero `global_bs` without
`mc2_mask`, which is incompatible with the hierarchical communication introduced in PR #7583.
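The restored rule, as a minimal sketch (names follow the diff below; this is illustrative, not the exact code in the PR):

```python
# Minimal sketch of the restored dispatch rule (illustrative only; the real
# code lives in TokenDispatcherWithMC2 below).
def build_mc2_kwargs(global_bs: int, mc2_mask):
    kwargs_mc2 = {"global_bs": global_bs}
    if global_bs == 0:
        # Uniform mode: every rank dispatches the same padded token count,
        # so the op needs a per-slot validity mask instead of a real count.
        kwargs_mc2["x_active_mask"] = mc2_mask
    # Non-uniform mode (DP allreduce skipped): real global_bs, no mask.
    return kwargs_mc2
```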
**Changes:**
- Add `should_skip_allreduce_across_dp_group()` to `utils.py`; it returns
`False` when hierarchical communication is enabled
- Set `global_bs=0` when allreduce is not skipped; pass `mc2_mask`
accordingly
- Add `mc2_mask` field to `MoEMC2CombineMetadata` for dispatch→combine
propagation
### Does this PR introduce _any_ user-facing change?
No. However, this PR fixes cross-super-node communication on A3 when
`enable_mc2_hierarchy_comm=True` is set in `additional_config` and
`HCCL_INTRA_ROCE_ENABLE=1` is exported.
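For reference, a hypothetical enablement sketch (it assumes vLLM's offline API forwards `additional_config`, and the env var must be set before engine startup; adjust to your launch method):

```python
# Hypothetical enablement sketch for A3 cross-super-node serving; the model
# path is a placeholder and `additional_config` forwarding is assumed.
import os

os.environ["HCCL_INTRA_ROCE_ENABLE"] = "1"  # must be exported before engine init

from vllm import LLM

llm = LLM(
    model="path/to/your-moe-model",  # placeholder
    additional_config={"enable_mc2_hierarchy_comm": True},
)
```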
### How was this patch tested?
E2E serving succeeded and CI passed.
- vLLM version: v0.18.0
- vLLM main: 14acf429ac
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
```diff
@@ -92,6 +92,7 @@ class MoEMC2CombineMetadata:
     assist_info_for_combine: torch.Tensor
     expand_scales: torch.Tensor | None
     dispatch_with_quant: bool
+    mc2_mask: torch.Tensor | None = None


 @dataclass(frozen=True, slots=True)
```
```diff
@@ -40,7 +40,12 @@ from vllm_ascend.ops.fused_moe.moe_runtime_args import (
     MoETokenDispatchOutput,
     TMoECombineMetadata,
 )
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
+from vllm_ascend.utils import (
+    AscendDeviceType,
+    get_ascend_device_type,
+    is_hierarchical_communication_enabled,
+    should_skip_allreduce_across_dp_group,
+)


 class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
```
```diff
@@ -115,7 +120,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
         else:
             max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
+        _max_global_bs = num_tokens_per_tp_rank * self.ep_world_size
+
+        # When allreduce across DP is not skipped, tokens are uniform across ranks:
+        # use global_bs=0 (uniform mode) and pass mc2_mask.
+        # When allreduce is skipped, tokens may differ per rank:
+        # use the real global_bs and do NOT pass mc2_mask.
+        self.global_bs = _max_global_bs if should_skip_allreduce_across_dp_group(vllm_config) else 0
 
         # NOTE: When enable_mc2_hierarchy_comm is true, we need pass in `comm_alg` to mc2 op.
         self.need_comm_alg = get_ascend_config().enable_mc2_hierarchy_comm
```
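A worked example of the sizing logic above, with toy numbers (not taken from the PR):

```python
# Toy numbers only; real values come from the scheduler and parallel config.
max_num_tokens, tp_size, ep_world_size = 512, 4, 16
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size  # ceil div -> 128
_max_global_bs = num_tokens_per_tp_rank * ep_world_size             # -> 2048
# allreduce skipped -> global_bs = 2048 (per-rank token counts may differ)
# allreduce kept    -> global_bs = 0    (uniform mode; pass mc2_mask)
```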
```diff
@@ -156,6 +167,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
         }
+        if self.global_bs == 0:
+            kwargs_mc2["x_active_mask"] = token_dispatch_input.routing.mc2_mask
 
         stage1_kwargs = {
             "scales": None,
```
```diff
@@ -228,6 +241,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
                 assist_info_for_combine=assist_info_for_combine,
                 expand_scales=expand_scales,
                 dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
+                mc2_mask=token_dispatch_input.routing.mc2_mask if self.global_bs == 0 else None,
             ),
         )
 
```
```diff
@@ -251,6 +265,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
             "moe_expert_num": self.moe_expert_num,
             "global_bs": self.global_bs,
         }
+        if self.global_bs == 0:
+            kwargs_mc2["x_active_mask"] = combine_metadata.mc2_mask
 
         if combine_metadata.dispatch_with_quant:
             tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
```
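A toy illustration of the mask passed in uniform mode (assumption: `mc2_mask` flags which padded token slots hold real tokens, matching its use as `x_active_mask`):

```python
import torch

# Assumption: mc2_mask marks which padded token slots hold real tokens.
num_actual_tokens, padded_len = 3, 8
mc2_mask = torch.zeros(padded_len, dtype=torch.bool)
mc2_mask[:num_actual_tokens] = True
# tensor([True, True, True, False, False, False, False, False])
# With global_bs=0 every rank sends padded_len slots; the op skips False ones.
```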
In `vllm_ascend/utils.py`, the new helper:

```diff
@@ -997,6 +997,62 @@ def is_hierarchical_communication_enabled():
     ) or get_ascend_config().enable_mc2_hierarchy_comm
 
 
+def should_skip_allreduce_across_dp_group(vllm_config, is_draft_model: bool = False) -> bool:
+    """Decide whether to skip the all-reduce across the DP group.
+
+    Skipping is applicable for all dense models and for moe models only on ranks
+    that act as KV consumers. We skip the DP all-reduce when either:
+    - Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
+    - Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
+
+    Skipping means each rank may have a different number of tokens, so MC2 needs
+    a non-zero global_bs and must NOT receive mc2_mask.
+
+    Returns False when hierarchy comm is enabled because hierarchy requires
+    global_bs=0 (uniform tokens), which is incompatible with skipping allreduce.
+    """
+    if is_hierarchical_communication_enabled():
+        return False
+
+    # For dense models, since we don't actually need dp communication, we simply skip it.
+    # This usually happens when main model is moe while eagle draft model is dense.
+    is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
+    if not is_context_moe_model:
+        return True
+
+    # Only applicable to MoE models on KV consumer ranks.
+    is_kv_consumer = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer
+    if not is_kv_consumer:
+        return False
+
+    from vllm_ascend.ascend_forward_context import select_moe_comm_method
+    from vllm_ascend.ops.fused_moe.moe_comm_method import MoECommType
+
+    def needs_mc2(n: int) -> bool:
+        return select_moe_comm_method(n, vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
+
+    compilation_config = vllm_config.compilation_config
+    scheduler_config = vllm_config.scheduler_config
+    speculative_config = vllm_config.speculative_config
+    uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens
+    decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
+    max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+
+    # Determine whether decode must use MC2. Use max cudagraph capture size
+    # if available, otherwise use the maximal uniform decode token count.
+    if compilation_config.cudagraph_capture_sizes:
+        potential_max_tokens = compilation_config.max_cudagraph_capture_size
+    else:
+        potential_max_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
+
+    decode_must_use_mc2 = needs_mc2(potential_max_tokens)
+    # For prefill, use the scheduler's max_num_batched_tokens for a single batch.
+    prefill_must_use_mc2 = needs_mc2(scheduler_config.max_num_batched_tokens)
+    # Skip all-reduce if decode requires MC2 and either prefill also
+    # requires MC2 or recompute-based scheduler is enabled.
+    return decode_must_use_mc2 and (prefill_must_use_mc2 or get_ascend_config().recompute_scheduler_enable)
+
+
 def has_layer_idx(model_instance: torch.nn.Module) -> bool:
     if model_instance is None:
         return False
```
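The helper's branch order, summarized as a simplified mirror (plain booleans instead of real configs; for illustration only):

```python
def classify(hierarchy: bool, moe: bool, kv_consumer: bool,
             decode_mc2: bool, prefill_mc2: bool, recompute: bool) -> bool:
    """Simplified mirror of should_skip_allreduce_across_dp_group."""
    if hierarchy:
        return False  # hierarchy needs global_bs=0 (uniform tokens)
    if not moe:
        return True   # dense models never need the DP sync
    if not kv_consumer:
        return False
    return decode_mc2 and (prefill_mc2 or recompute)

assert classify(False, True, True, True, True, False) is True    # P and D both MC2
assert classify(True,  True, True, True, True, False) is False   # hierarchy wins
assert classify(False, False, False, False, False, False) is True # dense model
```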
```diff
@@ -123,10 +123,9 @@ from vllm_ascend.utils import (
     enable_sp,
     enable_sp_by_pass,
     global_stream,
-    is_drafter_moe_model,
-    is_moe_model,
     lmhead_tp_enable,
     set_weight_prefetch_method,
+    should_skip_allreduce_across_dp_group,
 )
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 from vllm_ascend.worker.pcp_utils import PCPManager
```
```diff
@@ -469,46 +468,6 @@ class NPUModelRunner(GPUModelRunner):
             and not self.model_config.enforce_eager
         )
 
-    def _skip_all_reduce_across_dp_group(self, is_draft_model=False) -> bool:
-        """
-        Decide whether to skip the all-reduce across the data-parallel (DP) group.
-
-        Skipping is applicable for all dense models and for moe models only on ranks
-        that act as KV consumers. We skip the DP all-reduce when either:
-        - Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
-        - Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
-        """
-        # For dense models, since we don't actually need dp communication, we simply skip it.
-        # This usually happens when main model is moe while eagle draft model is dense.
-        is_context_moe_model = (
-            is_drafter_moe_model(self.vllm_config) if is_draft_model else is_moe_model(self.vllm_config)
-        )
-        if not is_context_moe_model:
-            return True
-
-        # Only applicable to MoE models on KV consumer ranks.
-        if not self.is_kv_consumer:
-            return False
-
-        def needs_mc2(num_tokens: int) -> bool:
-            return select_moe_comm_method(num_tokens, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
-
-        # Determine whether decode must use MC2. Use max cudagraph capture size
-        # if available, otherwise use the maximal uniform decode token count.
-        if self.compilation_config.cudagraph_capture_sizes:
-            potential_max_tokens = self.compilation_config.max_cudagraph_capture_size
-        else:
-            potential_max_tokens = self.max_num_reqs * self.uniform_decode_query_len
-        decode_must_use_mc2 = needs_mc2(potential_max_tokens)
-
-        # For prefill, use the scheduler's max_num_batched_tokens for a single
-        # batch.
-        prefill_must_use_mc2 = needs_mc2(self.vllm_config.scheduler_config.max_num_batched_tokens)
-
-        # Skip all-reduce if decode requires MC2 and either prefill also
-        # requires MC2 or recompute-based scheduler is enabled.
-        return decode_must_use_mc2 and (prefill_must_use_mc2 or self.ascend_config.recompute_scheduler_enable)
-
     def _sync_metadata_across_dp(
         self, num_tokens: int, with_prefill: bool = False, is_draft_model: bool = False
     ) -> tuple[int, torch.Tensor | None, bool]:
```
```diff
@@ -521,7 +480,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.dp_size == 1:
             return num_tokens, None, with_prefill
 
-        if self._skip_all_reduce_across_dp_group(is_draft_model):
+        if should_skip_allreduce_across_dp_group(self.vllm_config, is_draft_model):
             num_tokens_after_padding = torch.tensor([num_tokens] * self.dp_size, device="cpu", dtype=torch.int32)
             return num_tokens, num_tokens_after_padding, with_prefill
 
```
```diff
@@ -1891,7 +1850,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.dp_size == 1:
             return False, None, cudagraph_mode
 
-        if self._skip_all_reduce_across_dp_group():
+        if should_skip_allreduce_across_dp_group(self.vllm_config):
            num_tokens_after_padding = torch.tensor([num_tokens_padded] * self.dp_size, device="cpu", dtype=torch.int32)
             return False, num_tokens_after_padding, cudagraph_mode
 
```