From 6b853f15fe69ba335d2745ebcf14a164d0bcc505 Mon Sep 17 00:00:00 2001 From: Yuxiao-Xu <664988918@qq.com> Date: Mon, 9 Jun 2025 19:28:11 +0800 Subject: [PATCH] Add static EPLB (#1116) ### What this PR does / why we need it? Add EPLB expert map import capabilities ### Does this PR introduce _any_ user-facing change? When importing the EPLB expert map you need import expert map file by vllm args additional_config ### How was this patch tested? 1.You need to collect expert hotness and generate an expert placement file based on the hotness and the EPLB algorithm, or you can directly use an existing expert placement table. 2.When launching vLLM, enable EC2 and pass the configuration via the command-line argument: --additional-config '{"expert_map_path": "/xxx/xxx/xx.json"} Co-authored-by: songshanhu07 <1763685535@qq.com> --------- Signed-off-by: songshanhu07 <1763685535@qq.com> Signed-off-by: Yuxiao-Xu <664988918@qq.com> Signed-off-by: wangxiyuan Co-authored-by: songshanhu07 <1763685535@qq.com> Co-authored-by: Xu Yuxiao Co-authored-by: wangxiyuan --- docs/source/user_guide/additional_config.md | 13 +-- vllm_ascend/ascend_config.py | 1 + vllm_ascend/ops/expert_load_balancer.py | 99 +++++++++++++++++++++ vllm_ascend/ops/fused_moe.py | 38 ++++++-- vllm_ascend/quantization/quant_config.py | 5 +- vllm_ascend/quantization/w8a8_dynamic.py | 54 +++++++---- 6 files changed, 179 insertions(+), 31 deletions(-) create mode 100644 vllm_ascend/ops/expert_load_balancer.py diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index a884bda..c6558f4 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -24,12 +24,13 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"}) The following table lists the additional configuration options available in vLLM Ascend: -| Name | Type | Default | Description | -| ---- | ---- | ------- | ----------- | -| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode | -| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | -| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. | -| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. | +| Name | Type | Default | Description | +|-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------| +| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode | +| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | +| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. | +| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. | +| `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | The details of each config option are as follows: diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 065b7d0..0c072c3 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -38,6 +38,7 @@ class AscendConfig: self.expert_tensor_parallel_size = int( additional_config.get("expert_tensor_parallel_size", 0)) + self.expert_map_path = additional_config.get("expert_map_path", None) class TorchairGraphConfig: diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py new file mode 100644 index 0000000..c6eec64 --- /dev/null +++ b/vllm_ascend/ops/expert_load_balancer.py @@ -0,0 +1,99 @@ +import json +import random +from typing import Dict, List + +import torch + + +class ExpertLoadBalancer(object): + + def __init__(self, expert_map_path, global_expert_num): + self.expert_map_path = expert_map_path + self.global_expert_num = global_expert_num + self.expert_map_tensor, self.layers_num, self.ranks_num = ( + self._expert_file_to_tensor()) + + def _expert_file_to_tensor(self): + with open(self.expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + tensor_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + tensor_data.append(device_data) + expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, gpus_num + + def generate_index_dicts(self, tensor_2d): + dict_list = [] + current_idx = 0 + + for row in tensor_2d: + value_to_index = {} + for i in range(row.size(0)): + value = row[i].item() + value_to_index[value] = current_idx + i + dict_list.append(value_to_index) + current_idx += row.size(0) + + return dict_list + + def generate_expert_placement_map(self): + expert_placement_map = torch.full( + (self.layers_num, self.ranks_num, self.global_expert_num), + -1, + dtype=torch.int32, + ) + for layer_id in range(self.layers_num): + for gpu_id in range(self.ranks_num): + e_ids = self.expert_map_tensor[layer_id, gpu_id] + expert_placement_map[layer_id, gpu_id, + e_ids] = torch.arange(len(e_ids), + dtype=torch.int32) + return expert_placement_map + + def generate_log2phy_expert_map(self, layer_id): + concatenated = torch.flatten(self.expert_map_tensor[layer_id]) + rank_expert_to_global = self.generate_index_dicts( + self.expert_map_tensor[layer_id]) + result_dict: Dict[int, List[int]] = {} + for idx, value in enumerate(concatenated): + key = value.item() + if key not in result_dict: + result_dict[key] = [] + result_dict[key].append(idx) + + log2phy_map = torch.full((self.ranks_num, self.global_expert_num), + -1, + dtype=torch.int32) + for rank in range(self.ranks_num): + for key in result_dict: + indices_in_concat = result_dict[key] + if key in rank_expert_to_global[rank]: + log2phy_map[rank][key] = rank_expert_to_global[rank][key] + else: + chosen_index = random.choice(indices_in_concat) + log2phy_map[rank][key] = chosen_index + return log2phy_map + + def get_rank_placement_map(self, layer_id, rank_id): + expert_placement_map = self.generate_expert_placement_map() + layer_expert_map = expert_placement_map[layer_id] + rank_expert_map = layer_expert_map[rank_id].to( + torch.npu.current_device()) + rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item() + return rank_local_expert_num, rank_expert_map + + def get_rank_log2phy_map(self, layer_id, rank_id): + layer_log2phy_map = self.generate_log2phy_expert_map(layer_id) + return layer_log2phy_map[rank_id] + + def get_global_redundant_expert_num(self): + global_redundant_expert_num = ( + len(self.expert_map_tensor[0][0]) * self.ranks_num - + self.global_expert_num) + return global_redundant_expert_num diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c5f3178..25f3b05 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,6 +15,7 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm/tests/kernels/test_moe.py +import os from typing import Callable, List, Optional import torch @@ -34,6 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import \ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group +from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM @@ -956,6 +958,10 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): class AscendFusedMoE(FusedMoE): + # The moe_counter parameter is required during the initialization of EPLB + # to identify the current layer index within the MOE model. + moe_counter = -1 + def __init__( self, num_experts: int, # Global number of experts @@ -983,6 +989,9 @@ class AscendFusedMoE(FusedMoE): # fixme and make __init__() of AscendFusedMoE more clear super(FusedMoE, self).__init__() + AscendFusedMoE.moe_counter += 1 + self.moe_instance_id = AscendFusedMoE.moe_counter + if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -1016,16 +1025,33 @@ class AscendFusedMoE(FusedMoE): self.e_score_correction_bias = e_score_correction_bias self.expert_map = None self.activation = activation + self.log2phy = None + self.global_redundant_expert_num = 0 - # Create a tensor of size num_experts filled with -1 - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + ascend_config = get_ascend_config() + expert_map_path = ascend_config.expert_map_path + if expert_map_path and os.path.exists(expert_map_path): + # moe expert load balance + expert_load_balancer = ExpertLoadBalancer(expert_map_path, + self.global_num_experts) + self.local_num_experts, self.expert_map = \ + expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, + get_ep_group().rank_in_group) + self.log2phy = expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, + get_ep_group().rank_in_group) + self.global_redundant_expert_num = \ + expert_load_balancer.get_global_redundant_expert_num() + else: + # Create a tensor of size num_experts filled with -1 + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, + get_ep_group().rank_in_group, self.global_num_experts) self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group - ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on self.enable_multistream_shared_expert = \ @@ -1122,6 +1148,8 @@ class AscendFusedMoE(FusedMoE): e_score_correction_bias=self.e_score_correction_bias, is_prefill=is_prefill, enable_force_load_balance=enable_force_load_balance, + log2phy=self.log2phy, + global_redundant_expert_num=self.global_redundant_expert_num, **kwargs) if self.enable_multistream_shared_expert and not is_prefill: diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index e43f25d..3567dba 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -323,13 +323,16 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + global_redundant_expert_num=0, **kwargs, ) -> torch.Tensor: return self.quant_method.apply( layer, x, router_logits, top_k, renormalize, use_grouped_topk, global_num_experts, expert_map, topk_group, num_expert_group, custom_routing_function, scoring_func, e_score_correction_bias, - is_prefill, enable_force_load_balance, **kwargs) + is_prefill, enable_force_load_balance, log2phy, + global_redundant_expert_num, **kwargs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 68d70bc..2b7d57c 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -147,9 +147,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, top_k: int, expert_map: torch.Tensor = None, moe_all_to_all_group_name: str = "", + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, **kwargs) -> torch.Tensor: + + topk_ids = log2phy[topk_ids] global_bs = 0 - moe_expert_num = len(expert_map) + moe_expert_num = len(expert_map) + global_redundant_expert_num # hidden_states = hidden_states.bfloat16() kwargs_mc2 = { "x": hidden_states, @@ -271,7 +275,10 @@ def fused_experts_with_all2all( top_k: int, expert_map: torch.Tensor = None, ep_group: GroupCoordinator = None, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, ): + topk_ids = log2phy[topk_ids] original_shape = hidden_states.shape if len(original_shape) == 3: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -281,7 +288,7 @@ def fused_experts_with_all2all( device = hidden_states.device if expert_map is not None: - global_num_experts = len(expert_map) + global_num_experts = len(expert_map) + global_redundant_expert_num local_num_experts = global_num_experts // ep_group.world_size row_idx_len = num_tokens * top_k row_idx = (torch.arange(0, @@ -341,13 +348,14 @@ def fused_experts_with_all2all( group_list_type = 0 # `hidden_states` will be disposed in the `apply_mlp` function - hidden_states = apply_mlp(hidden_states, - w1, - w1_scale, - w2, - w2_scale, - expert_tokens, - group_list_type=group_list_type) + hidden_states = apply_mlp( + hidden_states, + w1, + w1_scale, #17 + w2, + w2_scale, + expert_tokens, #16 + group_list_type=group_list_type) if expert_map is not None: resorted_idx = torch.argsort(sorted_idx) @@ -639,6 +647,8 @@ class AscendW8A8DynamicFusedMoEMethod: e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -693,6 +703,8 @@ class AscendW8A8DynamicFusedMoEMethod: top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, **kwargs) elif self.torchair_graph_enabled or self.ep_group.world_size == 1: return fused_experts(hidden_states=x, @@ -709,16 +721,20 @@ class AscendW8A8DynamicFusedMoEMethod: # according to tp_size before they are feed into fused_moe module. # Therefore, all2all is needed no matter how dp/tp is set so as to # dispatch/combine tokens. - return fused_experts_with_all2all(hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=self.ep_group) + return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + ) def process_weights_after_loading(self, layer): if self.transpose_weight: