diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 2f2de07..9efb37a 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -53,6 +53,7 @@ class AscendConfig: # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config self.expert_map_path = additional_config.get("expert_map_path", None) + self.eplb_policy_type = additional_config.get("eplb_policy_type", 1) self.expert_map_record_path = additional_config.get( "expert_map_record_path", None) # Provide path to export expert map diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index cd460f8..15dbbd0 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -73,8 +73,12 @@ class EplbWorker: new_expert_maps = self.local2global(new_placement) self.update_expert_map(new_expert_maps) - update_info = self.compose_expert_update_info_greedy( - new_expert_maps, self.old_expert_maps) + if self.policy_type == 2: + update_info = self.compose_expert_update_info_bipartite( + new_expert_maps, self.old_expert_maps) + else: + update_info = self.compose_expert_update_info_greedy( + new_expert_maps, self.old_expert_maps) self.old_expert_maps = new_expert_maps logger.info("EPLB Process compute complete") diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index cefad46..7bddae0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -470,6 +470,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.dynamic_eplb = self.ascend_config.dynamic_eplb if self.dynamic_eplb: self.is_eplb_warmuped = False + self.policy_type = self.ascend_config.eplb_policy_type self.eplb_loader = D2DExpertWeightLoader() self.manager = Manager() self.shared_dict = self.manager.dict({ @@ -478,7 +479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): "expert_maps": None }) self.eplb_process = EplbProcess(shared_dict=self.shared_dict, - policy_type=1, + policy_type=self.policy_type, enable_d2d=True) self.process = self.eplb_process._launch_process() ascend_config = get_ascend_config()