diff --git a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
index 6c6ff263..3284e345 100644
--- a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
+++ b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
@@ -31,7 +31,8 @@ def mock_adaptor():
 
 
 def test_generate_task_and_state_flow(mock_adaptor):
-    loader_obj = loader.D2DExpertWeightLoader()
+    with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+        loader_obj = loader.D2DExpertWeightLoader()
     loader_obj.set_adator(mock_adaptor)
 
     with patch("torch.distributed.P2POp") as mock_p2p, \
@@ -52,7 +53,8 @@
 
 
 def test_asyn_transfer_and_update(mock_adaptor):
-    loader_obj = loader.D2DExpertWeightLoader()
+    with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+        loader_obj = loader.D2DExpertWeightLoader()
     loader_obj.set_adator(mock_adaptor)
 
     loader_obj.comm_op_list = ["fake_op"]
@@ -88,14 +90,16 @@
 
 
 def test_set_log2phy_map(mock_adaptor):
-    loader_obj = loader.D2DExpertWeightLoader()
+    with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+        loader_obj = loader.D2DExpertWeightLoader()
     loader_obj.set_adator(mock_adaptor)
     loader_obj.set_log2phy_map({"a": 1})
     assert loader_obj.updated_log2phy_map == {"a": 1}
 
 
 def test_invalid_state_asyn_update(mock_adaptor):
-    loader_obj = loader.D2DExpertWeightLoader()
+    with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+        loader_obj = loader.D2DExpertWeightLoader()
     loader_obj.set_adator(mock_adaptor)
 
     loader_obj.state = loader.ExpertWeightUpdateState.WAITING
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index e4a7c399..5f1c37f1 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -275,7 +275,7 @@ def get_fc3_quant_x_group() -> GroupCoordinator:
 
 
 def get_dynamic_eplb_group() -> GroupCoordinator:
-    assert _DYNAMIC_EPLB is not None, "fc3 quant x group is not initialized"
+    assert _DYNAMIC_EPLB is not None, "Dynamic eplb group is not initialized"
     return _DYNAMIC_EPLB
 
 
diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
index 728a61f2..79321aca 100644
--- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
+++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
@@ -19,6 +19,8 @@ from enum import Enum
 import torch.distributed as dist
 from vllm.logger import logger
 
+from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group
+
 
 class ExpertWeightUpdateState(Enum):
     WAITING = 0 # waiting for updated expert_map by EplbWorker
@@ -35,6 +37,7 @@ class D2DExpertWeightLoader:
         self.state = ExpertWeightUpdateState.WAITING
         self.recv_expert_list = []
         self.num_layers = 0
+        self.comm_group = get_dynamic_eplb_group()
 
     def set_adator(self, eplb_adaptor):
         self.eplb_adaptor = eplb_adaptor
@@ -53,12 +56,16 @@
             dst_rank, global_expert_id_to_send = send_info
             local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
             for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]:
-                self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
+                self.comm_op_list.append(
+                    dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group)
+                )
 
         for buffer_tensor_id, recv_info in enumerate(expert_recv_info):
             recv_rank, global_expert_id_to_recv = recv_info
             for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]:
-                self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
+                self.comm_op_list.append(
+                    dist.P2POp(dist.irecv, buffer_tensor, recv_rank, group=self.comm_group.device_group)
+                )
             local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item()
             self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id))
 
diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index fa06737d..285bd486 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -155,12 +155,12 @@ class EplbUpdator:
         for dst_rank in range(self.world_size):
             if dst_rank == self.rank_id:
                 continue
-            comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
+            comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group))
 
         for src_rank in range(self.world_size):
             if src_rank == self.rank_id:
                 continue
-            comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
+            comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank, group=self.comm_group.device_group))
 
         if comm_op_list:
             reqs = dist.batch_isend_irecv(comm_op_list)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 458b95cd..1d341899 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -62,7 +62,7 @@ _CP_CHUNKEDPREFILL_COMM_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
-_DYNAMIC_EPLB_BUFFER_SIZE = 1 # num_experts * num_layers * 64 byte
+_DYNAMIC_EPLB_BUFFER_SIZE = 100
 _IS_MOE_MODEL = None
 _IS_DRAFTER_MOE_MODEL = None
 _IS_VL_MODEL = None
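Below is a minimal standalone sketch (not part of the patch) of the pattern these changes apply: binding each dist.P2POp to a dedicated process group and issuing the ops together via dist.batch_isend_irecv, so send/recv pairs are matched inside that group instead of the default world group. It runs on CPU with the gloo backend; the group construction, ranks, port, and tensor shapes are illustrative assumptions, not taken from the PR.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Dedicated communication group for expert-weight traffic, analogous to
    # the dynamic-EPLB group returned by get_dynamic_eplb_group(); here it
    # simply spans all ranks for illustration.
    eplb_group = dist.new_group(ranks=list(range(world_size)))

    send_buf = torch.full((4,), float(rank))
    recv_buf = torch.empty(4)

    # Bind every P2POp to the dedicated group, then batch them, mirroring
    # the group=... argument the patch adds to each P2POp.
    ops = [
        dist.P2POp(dist.isend, send_buf, (rank + 1) % world_size, group=eplb_group),
        dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size, group=eplb_group),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()

    # Each rank should now hold the tensor sent by its left neighbor.
    assert recv_buf[0].item() == float((rank - 1) % world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)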