diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index f53d023c..965794ee 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -24,6 +24,8 @@
 _SHARD_WEIGHT: GroupCoordinator | None = None
 _P_TP: GroupCoordinator | None = None
 
+_DYNAMIC_EPLB: GroupCoordinator | None = None
+
 
 def init_ascend_model_parallel(
     parallel_config: ParallelConfig,
@@ -85,6 +87,12 @@ def init_ascend_model_parallel(
     _MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank,
                                      backend, group_name="mc2")
 
+    if get_ascend_config().eplb_config.dynamic_eplb:
+        global _DYNAMIC_EPLB
+        _DYNAMIC_EPLB = init_model_parallel_group(
+            group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
+        )
+
     # Initialize fine-grained TP process groups on Ascend for four components:
     # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
     # 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
@@ -265,6 +273,11 @@ def get_fc3_quant_x_group() -> GroupCoordinator:
     return _FC3_QUANT_X
 
 
+def get_dynamic_eplb_group() -> GroupCoordinator:
+    assert _DYNAMIC_EPLB is not None, "dynamic eplb group is not initialized"
+    return _DYNAMIC_EPLB
+
+
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
@@ -315,3 +328,8 @@ def destroy_ascend_model_parallel():
     if _FC3_QUANT_X:
         _FC3_QUANT_X.destroy()
         _FC3_QUANT_X = None
+
+    global _DYNAMIC_EPLB
+    if _DYNAMIC_EPLB:
+        _DYNAMIC_EPLB.destroy()
+        _DYNAMIC_EPLB = None
diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index 536786c2..55fab81f 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -21,6 +21,7 @@
 import torch.distributed as dist
 import vllm.envs as envs
 from vllm.logger import logger
+from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -34,6 +35,7 @@ class EplbUpdator:
         self.eplb_process = eplb_process
         self.shared_dict = self.eplb_process.shared_dict
         self.moe_imbalance_dict: dict[int, float] = {}
+        self.comm_group = get_dynamic_eplb_group()
 
     def set_adaptor(self, adaptor: VllmEplbAdaptor):
         self.adaptor = adaptor
@@ -41,8 +43,6 @@ class EplbUpdator:
         local_load = self.adaptor.get_rank_expert_workload()
         self.world_size = dist.get_world_size()
         self.device = local_load.device
-        shape = (self.world_size, *local_load.shape)
-        self._gather_buffer = torch.empty(shape, dtype=local_load.dtype, device=self.device)
         self.eplb_loader.num_layers = self.adaptor.num_dense_layers + self.adaptor.num_moe_layers
 
     def init_eplb(self, expert_map_path, process):
@@ -134,9 +134,8 @@ class EplbUpdator:
     def compute_and_set_moe_load(self):
         local_load = self.adaptor.get_rank_expert_workload()
 
-        dist.all_gather_into_tensor(self._gather_buffer, local_load)
+        moe_load = self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:])
 
-        moe_load = self._gather_buffer.permute(1, 0, 2)
         self.shared_dict["moe_load"] = moe_load.cpu()
         logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
 
@@ -183,17 +182,16 @@ class EplbUpdator:
         self.compute_and_set_moe_load()
 
         src_tensor = torch.empty((1,), device=self.device)
-        self_rank = dist.get_rank()
 
         comm_op_list = []
         for dst_rank in range(self.world_size):
-            if dst_rank == self_rank:
+            if dst_rank == self.rank_id:
                 continue
             comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
 
         for src_rank in range(self.world_size):
-            if src_rank == self_rank:
+            if src_rank == self.rank_id:
                 continue
             comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
 
         if comm_op_list:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 41f1a169..a37d46b9 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -62,6 +62,7 @@ _CP_CHUNKEDPREFILL_COMM_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
+_DYNAMIC_EPLB_BUFFER_SIZE = 1  # num_experts * num_layers * 64 byte
 _IS_MOE_MODEL = None
 _IS_DRAFTER_MOE_MODEL = None
 _IS_VL_MODEL = None
@@ -907,6 +908,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> dict | None:
         return None
     hccl_config_map = {
         "dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
+        "dynamic_eplb": {"hccl_buffer_size": _DYNAMIC_EPLB_BUFFER_SIZE},
     }
     return hccl_config_map.get(group_name, get_default_buffer_config())
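For context on the `comm_op_list` built in the last `eplb_updator.py` hunk (whose context ends at `if comm_op_list:`, before the call that drains the list), here is a minimal, self-contained sketch of the batched point-to-point handshake pattern that lists of `dist.P2POp` objects are typically executed with. The helper name `light_barrier` and the standalone-function form are illustrative assumptions, not part of this PR; the updator presumably launches its own list via `torch.distributed.batch_isend_irecv` inside the class.

```python
# Illustrative sketch only: `light_barrier` is a hypothetical helper, not code
# from this PR. It mirrors the pattern in the hunk above: every rank posts one
# isend and one irecv of a tiny tensor against every other rank, launches them
# as a single batch, and waits for completion.
import torch
import torch.distributed as dist


def light_barrier(rank_id: int, world_size: int, device: torch.device) -> None:
    src_tensor = torch.empty((1,), device=device)
    comm_op_list = []
    # Post a send to every peer except ourselves.
    for dst_rank in range(world_size):
        if dst_rank == rank_id:
            continue
        comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
    # Post a matching receive from every peer except ourselves.
    for src_rank in range(world_size):
        if src_rank == rank_id:
            continue
        comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
    if comm_op_list:
        # Launch all sends/receives as one batch and block until they finish.
        reqs = dist.batch_isend_irecv(comm_op_list)
        for req in reqs:
            req.wait()
```

Exchanging one-element tensors with every peer in a single batched launch effectively serves as a lightweight rendezvous after `compute_and_set_moe_load()`, without allocating a large collective buffer; the exact role it plays in the update sequence is not visible from this hunk alone.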