Add static EPLB (#1116)
### What this PR does / why we need it?
Add EPLB expert map import capabilities
### Does this PR introduce _any_ user-facing change?
When importing the EPLB expert map you need import expert map file by
vllm args additional_config
### How was this patch tested?
1.You need to collect expert hotness and generate an expert placement
file based on the hotness and the EPLB algorithm, or you can directly
use an existing expert placement table.
2.When launching vLLM, enable EC2 and pass the configuration via the
command-line argument:
--additional-config '{"expert_map_path": "/xxx/xxx/xx.json"}
Co-authored-by: songshanhu07 <1763685535@qq.com>
---------
Signed-off-by: songshanhu07 <1763685535@qq.com>
Signed-off-by: Yuxiao-Xu <664988918@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: songshanhu07 <1763685535@qq.com>
Co-authored-by: Xu Yuxiao <xuyuxiao2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -323,13 +323,16 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num=0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||
is_prefill, enable_force_load_balance, **kwargs)
|
||||
is_prefill, enable_force_load_balance, log2phy,
|
||||
global_redundant_expert_num, **kwargs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
|
||||
@@ -147,9 +147,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: str = "",
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs) -> torch.Tensor:
|
||||
|
||||
topk_ids = log2phy[topk_ids]
|
||||
global_bs = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
# hidden_states = hidden_states.bfloat16()
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
@@ -271,7 +275,10 @@ def fused_experts_with_all2all(
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
):
|
||||
topk_ids = log2phy[topk_ids]
|
||||
original_shape = hidden_states.shape
|
||||
if len(original_shape) == 3:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
@@ -281,7 +288,7 @@ def fused_experts_with_all2all(
|
||||
device = hidden_states.device
|
||||
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map)
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
@@ -341,13 +348,14 @@ def fused_experts_with_all2all(
|
||||
group_list_type = 0
|
||||
|
||||
# `hidden_states` will be disposed in the `apply_mlp` function
|
||||
hidden_states = apply_mlp(hidden_states,
|
||||
w1,
|
||||
w1_scale,
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_tokens,
|
||||
group_list_type=group_list_type)
|
||||
hidden_states = apply_mlp(
|
||||
hidden_states,
|
||||
w1,
|
||||
w1_scale, #17
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_tokens, #16
|
||||
group_list_type=group_list_type)
|
||||
|
||||
if expert_map is not None:
|
||||
resorted_idx = torch.argsort(sorted_idx)
|
||||
@@ -639,6 +647,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
@@ -693,6 +703,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
**kwargs)
|
||||
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
||||
return fused_experts(hidden_states=x,
|
||||
@@ -709,16 +721,20 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
# according to tp_size before they are feed into fused_moe module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group)
|
||||
return fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
|
||||
Reference in New Issue
Block a user