Files
xc-llm-ascend/vllm_ascend/eplb/eplb_updator.py
LI SHENGYONG 1e05c4908f [EPLB] Reduce the memory used for batch_isend_irecv (#7344)
### What this PR does / why we need it?

#6729 appears to reduce the NPU memory usage of EPLB, but it actually only
moves the buffer allocation from dist.all_gather_into_tensor to
dist.batch_isend_irecv, so the overall NPU memory usage is not reduced.
This PR removes the memory overhead of this part entirely.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Remaining memory of each rank before the fix.
<img width="649" height="99" alt="image"
src="https://github.com/user-attachments/assets/52a67592-e0e8-4f9a-b194-b84cb848c598"
/>

Remaining memory of each rank after the fix.
<img width="641" height="99" alt="image"
src="https://github.com/user-attachments/assets/0bc2e67c-f328-4dea-98af-d7a459fb4876"
/>

Close EPLB.
<img width="543" height="45" alt="image"
src="https://github.com/user-attachments/assets/6dcba19d-4401-44b8-a6d3-c7b35ee983c7"
/>

Memory of weights for each rank.
<img width="648" height="46" alt="image"
src="https://github.com/user-attachments/assets/4db2fd04-98a0-4d26-a026-2e8287102b99"
/>

Estimated memory for EPLB: 15.68  / 48 (layer_num) + 2 * 0.02 = 0.35 GB


- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
2026-03-20 12:25:58 +08:00

178 lines
7.1 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Once https://github.com/vllm-project/vllm/issues/22246 is merged into vLLM, remove this updator.
import numpy
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import logger
from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
class EplbUpdator:
    """Step the EPLB (expert load balancing) cycle via per-forward hooks.

    The updator keeps an iteration counter, ``cur_iterations``, that is
    advanced once per ``forward_end`` call. With ``H`` the expert heat
    collection interval, ``A`` the algorithm execution interval and ``L`` the
    number of MoE layers, one full cycle looks like this:

    * ``cur_iterations == H - 1``: ``forward_end`` gathers the expert load of
      all ranks and wakes the EPLB worker process.
    * ``cur_iterations == H + A - 1``: ``forward_before`` blocks until the
      worker has delivered the update info for all layers.
    * ``H + A <= cur_iterations < H + A + L``: one queued layer update is
      consumed per iteration; ``forward_before`` starts the device-to-device
      transfer and ``forward_end`` applies the new map and weights.
    * When the counter reaches ``H + A + L`` it is reset to ``0``.
    """

    # Project types are annotated as strings so they are not evaluated when
    # the method objects are created.
    def __init__(self, eplb_config, loader: "D2DExpertWeightLoader", eplb_process: "EplbProcess", process) -> None:
        """Store the collaborators and initialise the cycle state.

        Args:
            eplb_config: Configuration object; the attributes read here are
                ``eplb_policy_type``, ``expert_map_path``,
                ``expert_map_record_path``, ``expert_heat_collection_interval``
                and ``algorithm_execution_interval``.
            loader: Device-to-device expert weight loader.
            eplb_process: Worker wrapper exposing ``shared_dict``,
                ``planner_q`` and ``block_update_q``.
            process: Handle of the launched EPLB process (``pid``,
                ``is_alive``, ``terminate`` and ``join`` are used).
        """
        self.eplb_config = eplb_config
        # Policy type 3 is the one handled as "multi stage" by this class.
        self.multi_stage = eplb_config.eplb_policy_type == 3
        self.init_eplb(self.eplb_config.expert_map_path, process)
        self.eplb_loader = loader
        self.eplb_process = eplb_process
        self.shared_dict = self.eplb_process.shared_dict
        self.comm_group = get_dynamic_eplb_group()

    def set_adaptor(self, adaptor: "VllmEplbAdaptor") -> None:
        """Attach the model adaptor and derive layer, world-size and device info."""
        self.adaptor = adaptor
        self.num_moe_layers = self.adaptor.num_moe_layers
        # The workload tensor is fetched here only to learn which device it lives on.
        local_load = self.adaptor.get_rank_expert_workload()
        self.world_size = dist.get_world_size()
        self.device = local_load.device
        self.eplb_loader.num_layers = self.adaptor.num_dense_layers + self.adaptor.num_moe_layers

    def init_eplb(self, expert_map_path, process) -> None:
        """Initialise rank id, intervals, counters and the process handle.

        Args:
            expert_map_path: Path stored as ``self.expert_map_path``.
            process: Handle of the launched EPLB process.
        """
        self.rank_id = dist.get_rank()
        self.num_expert_load_gather = 10
        self.periodic_load_gather = True
        # NOTE(review): the intervals are assumed to be plain ints; they are
        # only added to and compared with the integer iteration counter.
        self.expert_heat_collection_interval: int = self.eplb_config.expert_heat_collection_interval
        self.expert_map_path = expert_map_path
        self.expert_map_record_path = self.eplb_config.expert_map_record_path

        # Deliberate best effort: if the flag cannot be read for any reason it
        # is treated exactly like a disabled flag.
        try:
            load_collecting_allowed = bool(envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING)
        except Exception:
            load_collecting_allowed = False
        if not load_collecting_allowed:
            self.num_expert_load_gather = self.expert_heat_collection_interval
            self.periodic_load_gather = False

        self.reqs = []
        self.update_info_all = []

        self.cur_iterations: int = 0

        self.algorithm_execution_interval: int = self.eplb_config.algorithm_execution_interval

        self.process = process

        logger.info(f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")

    def update_iteration(self) -> None:
        """Advance the iteration counter and close the cycle when it is complete.

        When the counter reaches ``H + A + L`` the expert maps are exported
        (only if ``expert_map_record_path`` is set), the model's MoE load
        statistics are cleared and the counter restarts at ``0``.
        """
        self.cur_iterations += 1
        cycle_length = (
            self.expert_heat_collection_interval + self.algorithm_execution_interval + self.num_moe_layers
        )
        if self.cur_iterations != cycle_length:
            return

        if self.expert_map_record_path is not None:
            self.adaptor._export_tensor_to_file(self.shared_dict["expert_maps"], self.expert_map_record_path)

        self.adaptor.model.clear_all_moe_loads()
        self.cur_iterations = 0

    def get_update_info_flag(self) -> bool:
        """Return True on the iteration where the worker's update info is fetched."""
        return self.cur_iterations == (self.expert_heat_collection_interval + self.algorithm_execution_interval - 1)

    def wakeup_eplb_worker_flag(self) -> bool:
        """Return True on the last iteration of the heat collection interval."""
        return self.cur_iterations == (self.expert_heat_collection_interval - 1)

    def update_expert_weight_flag(self) -> bool:
        """Return True while the counter is inside the per-layer weight update window."""
        weight_update_counter = self.cur_iterations - (
            self.expert_heat_collection_interval + self.algorithm_execution_interval
        )
        return 0 <= weight_update_counter < self.num_moe_layers

    def wakeup_eplb_worker(self) -> None:
        """Signal the EPLB worker by putting a token on its planner queue."""
        self.eplb_process.planner_q.put(1)

    def forward_before(self) -> None:
        """Hook for the start of a forward pass: fetch and launch pending updates."""
        # Batch after eplb process being triggered, get update info provided by eplb process
        if self.get_update_info_flag():
            # Blocking read from the worker's result queue.
            self.update_info_all = self.eplb_process.block_update_q.get()

        if not self.update_expert_weight_flag():
            return

        # One queued entry is consumed per iteration of the update window.
        (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)

        log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
        self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)

        updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
        self.eplb_loader.generate_expert_d2d_transfer_task(
            expert_send_info,
            expert_recv_info,
            updated_expert_map_this_rank,
            # The layer id is offset by the number of dense layers before it
            # is handed to the loader.
            layer_id + self.adaptor.num_dense_layers,
        )

        # set asynchronous stream for d2d expert weight update
        self.reqs = []
        self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

    def forward_end(self) -> None:
        """Hook for the end of a forward pass: trigger the worker, apply updates, step."""
        if self.wakeup_eplb_worker_flag():
            self.compute_and_set_moe_load()
            self.wakeup_eplb_worker()

        # The map and weight update is skipped whenever a record path is configured.
        if self.update_expert_weight_flag() and self.expert_map_record_path is None:
            self.eplb_loader.update_expert_map_and_weight(self.reqs)

        self.update_iteration()

    def compute_and_set_moe_load(self):
        """Gather the expert workload of all ranks and publish it in ``shared_dict``.

        The local workload is all-gathered along dim 0 over the EPLB
        communication group, reshaped to
        ``(-1, world_size, *local_load.shape[1:])`` and moved to the CPU. In
        multi-stage mode the axes are additionally permuted with
        ``(2, 0, 1, 3)``.

        Returns:
            The gathered load, as also stored under ``shared_dict["moe_load"]``.
        """
        local_load = self.adaptor.get_rank_expert_workload()
        gathered = self.comm_group.all_gather(local_load, dim=0)
        moe_load = gathered.reshape(-1, self.world_size, *local_load.shape[1:]).cpu()

        if self.multi_stage:
            moe_load = moe_load.permute(2, 0, 1, 3)

        self.shared_dict["moe_load"] = moe_load
        logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
        return moe_load

    def warm_up_eplb(self) -> None:
        """Publish the expert maps and load, then exchange a dummy tensor with every peer.

        NOTE(review): the one-element exchange presumably exists to set up the
        point-to-point channels before real weight transfers - confirm.
        """
        self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map()
        self.compute_and_set_moe_load()

        # Uninitialised one-element buffer, reused for every send and receive.
        src_tensor = torch.empty((1,), device=self.device)

        peer_ranks = [rank for rank in range(self.world_size) if rank != self.rank_id]
        # All sends are queued first, then all receives, each in ascending rank order.
        comm_op_list = [
            dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group)
            for dst_rank in peer_ranks
        ]
        comm_op_list.extend(
            dist.P2POp(dist.irecv, src_tensor, src_rank, group=self.comm_group.device_group)
            for src_rank in peer_ranks
        )

        # Without peers nothing is issued, so there is nothing to wait for.
        if not comm_op_list:
            return

        for req in dist.batch_isend_irecv(comm_op_list):
            req.wait()

    def shutdown(self) -> None:
        """
        Clean up the EPLB process.
        """
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
            logger.info("[ModelRunner] EPLB process terminated")