[BugFix]Support redundant experts in EPLB (#3473)
This PR adds support for redundant experts in the EPLB. Key points: - Use global_num_experts = num_experts + num_redundant_experts consistently. - Backward compatible when num_redundant_experts=0. Tested On a 16-rank setup (W8A8) with static EPLB and expert_map_path, verifying router logits shape and successful requests. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: yechao237 <yechao20180411@gmail.com>
This commit is contained in:
@@ -40,14 +40,6 @@ def determine_default_expert_map(global_expert_num, world_size, rank_id,
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - rank_id * local_num_experts
|
||||
|
||||
if isinstance(global_redundant_expert_num,
|
||||
int) and rank_id < global_redundant_expert_num:
|
||||
local_count += 1
|
||||
if end < global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map[start:end] = local_ids
|
||||
@@ -118,14 +110,6 @@ def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - r * local_num_experts
|
||||
|
||||
if isinstance(global_redundant_expert_num,
|
||||
int) and rank_id < global_redundant_expert_num:
|
||||
local_count += 1
|
||||
if end < global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map_all[r, start:end] = local_ids
|
||||
|
||||
@@ -20,6 +20,8 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -176,7 +178,8 @@ def _select_experts_with_fusion_ops(
|
||||
|
||||
topk_weights, topk_ids = None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
|
||||
@@ -123,10 +123,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
):
|
||||
if self.with_quant:
|
||||
quant_mode = 2
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
moe_expert_num = len(expert_map)
|
||||
else:
|
||||
quant_mode = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
|
||||
@@ -263,7 +263,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
topk_weights, topk_ids = select_experts(
|
||||
|
||||
@@ -263,7 +263,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
||||
@@ -203,7 +203,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
||||
@@ -856,8 +856,9 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
global_redundant_expert_num = get_ascend_config(
|
||||
).init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
|
||||
@@ -269,7 +269,7 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
|
||||
@@ -246,7 +246,7 @@ def torchair_fused_experts_with_mc2(
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
moe_expert_num = len(expert_map)
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
# hidden_states = hidden_states.bfloat16()
|
||||
@@ -431,7 +431,7 @@ def torchair_fused_experts_with_all2all(
|
||||
|
||||
if expert_map is not None:
|
||||
assert ep_group is not None, "ep_group must be provided when expert_map is given"
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
global_num_experts = len(expert_map)
|
||||
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
||||
hidden_states,
|
||||
@@ -929,9 +929,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if is_deepseek_v3_r1:
|
||||
|
||||
Reference in New Issue
Block a user