[BugFix]Support redundant experts in EPLB (#3473)
This PR adds support for redundant experts in the EPLB. Key points: - Use global_num_experts = num_experts + num_redundant_experts consistently. - Backward compatible when num_redundant_experts=0. Tested On a 16-rank setup (W8A8) with static EPLB and expert_map_path, verifying router logits shape and successful requests. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: yechao237 <yechao20180411@gmail.com>
This commit is contained in:
@@ -269,7 +269,7 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
|
||||
@@ -246,7 +246,7 @@ def torchair_fused_experts_with_mc2(
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
moe_expert_num = len(expert_map)
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
# hidden_states = hidden_states.bfloat16()
|
||||
@@ -431,7 +431,7 @@ def torchair_fused_experts_with_all2all(
|
||||
|
||||
if expert_map is not None:
|
||||
assert ep_group is not None, "ep_group must be provided when expert_map is given"
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
global_num_experts = len(expert_map)
|
||||
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
||||
hidden_states,
|
||||
@@ -929,9 +929,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if is_deepseek_v3_r1:
|
||||
|
||||
Reference in New Issue
Block a user