Add static EPLB (#1116)

### What this PR does / why we need it?
   Add EPLB expert map import capabilities
### Does this PR introduce _any_ user-facing change?
When importing the EPLB expert map you need import expert map file by
vllm args additional_config
### How was this patch tested?
1.You need to collect expert hotness and generate an expert placement
file based on the hotness and the EPLB algorithm, or you can directly
use an existing expert placement table.
2.When launching vLLM, enable EC2 and pass the configuration via the
command-line argument:
      --additional-config '{"expert_map_path": "/xxx/xxx/xx.json"}
Co-authored-by: songshanhu07 <1763685535@qq.com>

---------

Signed-off-by: songshanhu07 <1763685535@qq.com>
Signed-off-by: Yuxiao-Xu <664988918@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: songshanhu07 <1763685535@qq.com>
Co-authored-by: Xu Yuxiao <xuyuxiao2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Yuxiao-Xu
2025-06-09 19:28:11 +08:00
committed by GitHub
parent cb341c7bcd
commit 6b853f15fe
6 changed files with 179 additions and 31 deletions

View File

@@ -323,13 +323,16 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num=0,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance, **kwargs)
is_prefill, enable_force_load_balance, log2phy,
global_redundant_expert_num, **kwargs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):

View File

@@ -147,9 +147,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
**kwargs) -> torch.Tensor:
topk_ids = log2phy[topk_ids]
global_bs = 0
moe_expert_num = len(expert_map)
moe_expert_num = len(expert_map) + global_redundant_expert_num
# hidden_states = hidden_states.bfloat16()
kwargs_mc2 = {
"x": hidden_states,
@@ -271,7 +275,10 @@ def fused_experts_with_all2all(
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
):
topk_ids = log2phy[topk_ids]
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -281,7 +288,7 @@ def fused_experts_with_all2all(
device = hidden_states.device
if expert_map is not None:
global_num_experts = len(expert_map)
global_num_experts = len(expert_map) + global_redundant_expert_num
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
@@ -341,13 +348,14 @@ def fused_experts_with_all2all(
group_list_type = 0
# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(hidden_states,
w1,
w1_scale,
w2,
w2_scale,
expert_tokens,
group_list_type=group_list_type)
hidden_states = apply_mlp(
hidden_states,
w1,
w1_scale, #17
w2,
w2_scale,
expert_tokens, #16
group_list_type=group_list_type)
if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
@@ -639,6 +647,8 @@ class AscendW8A8DynamicFusedMoEMethod:
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -693,6 +703,8 @@ class AscendW8A8DynamicFusedMoEMethod:
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
**kwargs)
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
@@ -709,16 +721,20 @@ class AscendW8A8DynamicFusedMoEMethod:
# according to tp_size before they are feed into fused_moe module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group)
return fused_experts_with_all2all(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
)
def process_weights_after_loading(self, layer):
if self.transpose_weight: