From 1e05c4908f31737bc4eef865a9f351d030a77c9d Mon Sep 17 00:00:00 2001
From: LI SHENGYONG <49200266+shenchuxiaofugui@users.noreply.github.com>
Date: Fri, 20 Mar 2026 12:25:58 +0800
Subject: [PATCH] [EPLB] Reduce the memory used for batch_isend_irecv (#7344)
### What this PR does / why we need it?
#6729 seems to reduce the NPU memory usage of eplb, but actually moves
the buffer allocation of dist.all_gather_into_tensor to
dist.batch_isend_irecv. Therefore, the overall NPU memory usage is not
reduced. This PR completely reduces the memory usage in this part.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Remaining memory of each rank before the fix.
Remaining memory of each rank after the fix.
With EPLB disabled:
Memory of weights for each rank.
Estimated memory for EPLB: 15.68 / 48 (layer_num) + 2 * 0.02 ≈ 0.37 GB
- vLLM version: v0.17.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
---
.../ut/eplb/core/test_eplb_device_transfer_loader.py | 12 ++++++++----
vllm_ascend/distributed/parallel_state.py | 2 +-
vllm_ascend/eplb/core/eplb_device_transfer_loader.py | 11 +++++++++--
vllm_ascend/eplb/eplb_updator.py | 4 ++--
vllm_ascend/utils.py | 2 +-
5 files changed, 21 insertions(+), 10 deletions(-)
diff --git a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
index 6c6ff263..3284e345 100644
--- a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
+++ b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
@@ -31,7 +31,8 @@ def mock_adaptor():
def test_generate_task_and_state_flow(mock_adaptor):
- loader_obj = loader.D2DExpertWeightLoader()
+ with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+ loader_obj = loader.D2DExpertWeightLoader()
loader_obj.set_adator(mock_adaptor)
with patch("torch.distributed.P2POp") as mock_p2p, \
@@ -52,7 +53,8 @@ def test_generate_task_and_state_flow(mock_adaptor):
def test_asyn_transfer_and_update(mock_adaptor):
- loader_obj = loader.D2DExpertWeightLoader()
+ with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+ loader_obj = loader.D2DExpertWeightLoader()
loader_obj.set_adator(mock_adaptor)
loader_obj.comm_op_list = ["fake_op"]
@@ -88,14 +90,16 @@ def test_asyn_transfer_and_update(mock_adaptor):
def test_set_log2phy_map(mock_adaptor):
- loader_obj = loader.D2DExpertWeightLoader()
+ with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+ loader_obj = loader.D2DExpertWeightLoader()
loader_obj.set_adator(mock_adaptor)
loader_obj.set_log2phy_map({"a": 1})
assert loader_obj.updated_log2phy_map == {"a": 1}
def test_invalid_state_asyn_update(mock_adaptor):
- loader_obj = loader.D2DExpertWeightLoader()
+ with patch("vllm_ascend.eplb.core.eplb_device_transfer_loader.get_dynamic_eplb_group", return_value=None):
+ loader_obj = loader.D2DExpertWeightLoader()
loader_obj.set_adator(mock_adaptor)
loader_obj.state = loader.ExpertWeightUpdateState.WAITING
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index e4a7c399..5f1c37f1 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -275,7 +275,7 @@ def get_fc3_quant_x_group() -> GroupCoordinator:
def get_dynamic_eplb_group() -> GroupCoordinator:
- assert _DYNAMIC_EPLB is not None, "fc3 quant x group is not initialized"
+ assert _DYNAMIC_EPLB is not None, "Dynamic eplb group is not initialized"
return _DYNAMIC_EPLB
diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
index 728a61f2..79321aca 100644
--- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
+++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
@@ -19,6 +19,8 @@ from enum import Enum
import torch.distributed as dist
from vllm.logger import logger
+from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group
+
class ExpertWeightUpdateState(Enum):
WAITING = 0 # waiting for updated expert_map by EplbWorker
@@ -35,6 +37,7 @@ class D2DExpertWeightLoader:
self.state = ExpertWeightUpdateState.WAITING
self.recv_expert_list = []
self.num_layers = 0
+ self.comm_group = get_dynamic_eplb_group()
def set_adator(self, eplb_adaptor):
self.eplb_adaptor = eplb_adaptor
@@ -53,12 +56,16 @@ class D2DExpertWeightLoader:
dst_rank, global_expert_id_to_send = send_info
local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]:
- self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
+ self.comm_op_list.append(
+ dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group)
+ )
for buffer_tensor_id, recv_info in enumerate(expert_recv_info):
recv_rank, global_expert_id_to_recv = recv_info
for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]:
- self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
+ self.comm_op_list.append(
+ dist.P2POp(dist.irecv, buffer_tensor, recv_rank, group=self.comm_group.device_group)
+ )
local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item()
self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id))
diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index fa06737d..285bd486 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -155,12 +155,12 @@ class EplbUpdator:
for dst_rank in range(self.world_size):
if dst_rank == self.rank_id:
continue
- comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
+ comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group))
for src_rank in range(self.world_size):
if src_rank == self.rank_id:
continue
- comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
+ comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank, group=self.comm_group.device_group))
if comm_op_list:
reqs = dist.batch_isend_irecv(comm_op_list)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 458b95cd..1d341899 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -62,7 +62,7 @@ _CP_CHUNKEDPREFILL_COMM_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
_DEFAULT_BUFFER_SIZE = 200
_MIN_DP_BUFFER_SIZE = 50
-_DYNAMIC_EPLB_BUFFER_SIZE = 1 # num_experts * num_layers * 64 byte
+_DYNAMIC_EPLB_BUFFER_SIZE = 100
_IS_MOE_MODEL = None
_IS_DRAFTER_MOE_MODEL = None
_IS_VL_MODEL = None