[EPLB][Bugfix]Reduce unnecessary video memory usage (#6020)

### What this PR does / why we need it?
1.Incorporate the warm up of the EPLB into the profile run.
2.Reusing the same gather buffer

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
qwen3-235b aime baseline
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

eplb The OOM issue does not occur.
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
LI SHENGYONG
2026-01-23 14:21:13 +08:00
committed by GitHub
parent 749e24f81e
commit 8210a62a44
4 changed files with 20 additions and 30 deletions

View File

@@ -21,6 +21,7 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import logger
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
@@ -35,9 +36,16 @@ class EplbUpdator:
self.shared_dict = self.eplb_process.shared_dict
self.moe_imbalance_dict: dict[int, float] = {}
def set_adaptor(self, adaptor):
def set_adaptor(self, adaptor: VllmEplbAdaptor):
self.adaptor = adaptor
self.num_moe_layers = self.adaptor.num_moe_layers
local_load = self.adaptor.get_rank_expert_workload()
self.world_size = dist.get_world_size()
self.device = local_load.device
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)
def init_eplb(self, expert_map_path, process):
self.rank_id = dist.get_rank()
@@ -122,7 +130,7 @@ class EplbUpdator:
def forward_end(self):
if self.wakeup_eplb_worker_flag():
self.compute_and_set_moe_load(is_clear=True)
self.compute_and_set_moe_load()
self.wakeup_eplb_worker()
if self.update_expert_weight_flag(
@@ -131,34 +139,17 @@ class EplbUpdator:
self.update_iteration()
def compute_and_set_moe_load(self, is_clear=False):
def compute_and_set_moe_load(self):
local_load = self.adaptor.get_rank_expert_workload()
dist.all_gather_into_tensor(self._gather_buffer, local_load)
self._gather_buffer = None
if dist.is_initialized():
self.world_size = dist.get_world_size()
self.device = local_load.device
if self._gather_buffer is None:
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)
moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
dist.all_gather_into_tensor(self._gather_buffer, local_load)
moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
else:
moe_load = local_load.unsqueeze(1)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
if dist.is_initialized() and dist.get_rank() == 0:
if dist.get_rank() == 0:
self.compute_moe_imbalance(moe_load)
self.summarize_moe_imbalance()