diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py index ce1c3d73..ba24c88c 100644 --- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -60,7 +60,6 @@ class D2DExpertWeightLoader: layer_id][global_expert_id_to_send].item() for src_tensor in self.eplb_adaptor.expert_param_per_layer[ layer_id][local_expert_id]: - src_tensor = src_tensor.clone() self.comm_op_list.append( dist.P2POp(dist.isend, src_tensor, dst_rank)) diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index b76a4bbd..cf7cece4 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -21,6 +21,7 @@ import torch.distributed as dist import vllm.envs as envs from vllm.logger import logger +from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_worker import EplbProcess @@ -35,9 +36,16 @@ class EplbUpdator: self.shared_dict = self.eplb_process.shared_dict self.moe_imbalance_dict: dict[int, float] = {} - def set_adaptor(self, adaptor): + def set_adaptor(self, adaptor: VllmEplbAdaptor): self.adaptor = adaptor self.num_moe_layers = self.adaptor.num_moe_layers + local_load = self.adaptor.get_rank_expert_workload() + self.world_size = dist.get_world_size() + self.device = local_load.device + shape = (self.world_size, *local_load.shape) + self._gather_buffer = torch.empty(shape, + dtype=local_load.dtype, + device=self.device) def init_eplb(self, expert_map_path, process): self.rank_id = dist.get_rank() @@ -122,7 +130,7 @@ class EplbUpdator: def forward_end(self): if self.wakeup_eplb_worker_flag(): - self.compute_and_set_moe_load(is_clear=True) + self.compute_and_set_moe_load() self.wakeup_eplb_worker() if self.update_expert_weight_flag( @@ -131,34 +139,17 @@ class EplbUpdator: self.update_iteration() - def compute_and_set_moe_load(self, is_clear=False): + def compute_and_set_moe_load(self): local_load = self.adaptor.get_rank_expert_workload() + dist.all_gather_into_tensor(self._gather_buffer, local_load) - self._gather_buffer = None - if dist.is_initialized(): - self.world_size = dist.get_world_size() - self.device = local_load.device - if self._gather_buffer is None: - shape = (self.world_size, *local_load.shape) - self._gather_buffer = torch.empty(shape, - dtype=local_load.dtype, - device=self.device) + moe_load = self._gather_buffer.permute(1, 0, 2) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug( + f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" + ) - dist.all_gather_into_tensor(self._gather_buffer, local_load) - - moe_load = self._gather_buffer.permute(1, 0, 2) - self.shared_dict["moe_load"] = moe_load.cpu() - logger.debug( - f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" - ) - else: - moe_load = local_load.unsqueeze(1) - self.shared_dict["moe_load"] = moe_load.cpu() - logger.debug( - f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" - ) - - if dist.is_initialized() and dist.get_rank() == 0: + if dist.get_rank() == 0: self.compute_moe_imbalance(moe_load) self.summarize_moe_imbalance() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index cd728982..90745515 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2266,7 +2266,7 @@ class NPUModelRunner(GPUModelRunner): is_profile=is_profile) if is_profile and self.dynamic_eplb: self.model.clear_all_moe_loads() - if not is_profile and self.dynamic_eplb: + if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.forward_end() return hidden_states, hidden_states @@ -2293,6 +2293,7 @@ class NPUModelRunner(GPUModelRunner): return output def profile_run(self) -> None: + self.eplb_warmup() mc2_tokens_capacity = get_mc2_tokens_capacity() if self.max_num_tokens > mc2_tokens_capacity and \ select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}: diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index f697f928..b8f257cd 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -366,7 +366,6 @@ class NPUWorker(WorkerBase): def compile_or_warm_up_model(self) -> None: # Note: need to adapt for graph mode. - self.model_runner.eplb_warmup() warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy() if not self.model_config.enforce_eager: