[EPLB] Reduce the memory used for heat aggregation (#6729)
### What this PR does / why we need it?
If dist.all_gather is used directly, 2 x HCCL_BUFFSIZE memory will be
consumed, but the actual memory required for hotspot aggregation is less
than 1 MB. Therefore, a separate small communication domain is created
for it.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Original:

Current:

- vLLM version: v0.15.0
- vLLM main:
9562912cea
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -24,6 +24,8 @@ _SHARD_WEIGHT: GroupCoordinator | None = None
|
|||||||
|
|
||||||
_P_TP: GroupCoordinator | None = None
|
_P_TP: GroupCoordinator | None = None
|
||||||
|
|
||||||
|
_DYNAMIC_EPLB: GroupCoordinator | None = None
|
||||||
|
|
||||||
|
|
||||||
def init_ascend_model_parallel(
|
def init_ascend_model_parallel(
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
@@ -85,6 +87,12 @@ def init_ascend_model_parallel(
|
|||||||
|
|
||||||
_MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
|
_MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
|
||||||
|
|
||||||
|
if get_ascend_config().eplb_config.dynamic_eplb:
|
||||||
|
global _DYNAMIC_EPLB
|
||||||
|
_DYNAMIC_EPLB = init_model_parallel_group(
|
||||||
|
group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize fine-grained TP process groups on Ascend for four components:
|
# Initialize fine-grained TP process groups on Ascend for four components:
|
||||||
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
|
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
|
||||||
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
|
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
|
||||||
@@ -265,6 +273,11 @@ def get_fc3_quant_x_group() -> GroupCoordinator:
|
|||||||
return _FC3_QUANT_X
|
return _FC3_QUANT_X
|
||||||
|
|
||||||
|
|
||||||
|
def get_dynamic_eplb_group() -> GroupCoordinator:
|
||||||
|
assert _DYNAMIC_EPLB is not None, "fc3 quant x group is not initialized"
|
||||||
|
return _DYNAMIC_EPLB
|
||||||
|
|
||||||
|
|
||||||
def destroy_ascend_model_parallel():
|
def destroy_ascend_model_parallel():
|
||||||
global _MC2
|
global _MC2
|
||||||
if _MC2:
|
if _MC2:
|
||||||
@@ -315,3 +328,8 @@ def destroy_ascend_model_parallel():
|
|||||||
if _FC3_QUANT_X:
|
if _FC3_QUANT_X:
|
||||||
_FC3_QUANT_X.destroy()
|
_FC3_QUANT_X.destroy()
|
||||||
_FC3_QUANT_X = None
|
_FC3_QUANT_X = None
|
||||||
|
|
||||||
|
global _DYNAMIC_EPLB
|
||||||
|
if _DYNAMIC_EPLB:
|
||||||
|
_DYNAMIC_EPLB.destroy()
|
||||||
|
_DYNAMIC_EPLB = None
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import torch.distributed as dist
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group
|
||||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||||||
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
|
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
|
||||||
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||||||
@@ -34,6 +35,7 @@ class EplbUpdator:
|
|||||||
self.eplb_process = eplb_process
|
self.eplb_process = eplb_process
|
||||||
self.shared_dict = self.eplb_process.shared_dict
|
self.shared_dict = self.eplb_process.shared_dict
|
||||||
self.moe_imbalance_dict: dict[int, float] = {}
|
self.moe_imbalance_dict: dict[int, float] = {}
|
||||||
|
self.comm_group = get_dynamic_eplb_group()
|
||||||
|
|
||||||
def set_adaptor(self, adaptor: VllmEplbAdaptor):
|
def set_adaptor(self, adaptor: VllmEplbAdaptor):
|
||||||
self.adaptor = adaptor
|
self.adaptor = adaptor
|
||||||
@@ -41,8 +43,6 @@ class EplbUpdator:
|
|||||||
local_load = self.adaptor.get_rank_expert_workload()
|
local_load = self.adaptor.get_rank_expert_workload()
|
||||||
self.world_size = dist.get_world_size()
|
self.world_size = dist.get_world_size()
|
||||||
self.device = local_load.device
|
self.device = local_load.device
|
||||||
shape = (self.world_size, *local_load.shape)
|
|
||||||
self._gather_buffer = torch.empty(shape, dtype=local_load.dtype, device=self.device)
|
|
||||||
self.eplb_loader.num_layers = self.adaptor.num_dense_layers + self.adaptor.num_moe_layers
|
self.eplb_loader.num_layers = self.adaptor.num_dense_layers + self.adaptor.num_moe_layers
|
||||||
|
|
||||||
def init_eplb(self, expert_map_path, process):
|
def init_eplb(self, expert_map_path, process):
|
||||||
@@ -134,9 +134,8 @@ class EplbUpdator:
|
|||||||
|
|
||||||
def compute_and_set_moe_load(self):
|
def compute_and_set_moe_load(self):
|
||||||
local_load = self.adaptor.get_rank_expert_workload()
|
local_load = self.adaptor.get_rank_expert_workload()
|
||||||
dist.all_gather_into_tensor(self._gather_buffer, local_load)
|
moe_load = self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:])
|
||||||
|
|
||||||
moe_load = self._gather_buffer.permute(1, 0, 2)
|
|
||||||
self.shared_dict["moe_load"] = moe_load.cpu()
|
self.shared_dict["moe_load"] = moe_load.cpu()
|
||||||
logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
|
logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
|
||||||
|
|
||||||
@@ -183,17 +182,16 @@ class EplbUpdator:
|
|||||||
self.compute_and_set_moe_load()
|
self.compute_and_set_moe_load()
|
||||||
|
|
||||||
src_tensor = torch.empty((1,), device=self.device)
|
src_tensor = torch.empty((1,), device=self.device)
|
||||||
self_rank = dist.get_rank()
|
|
||||||
|
|
||||||
comm_op_list = []
|
comm_op_list = []
|
||||||
|
|
||||||
for dst_rank in range(self.world_size):
|
for dst_rank in range(self.world_size):
|
||||||
if dst_rank == self_rank:
|
if dst_rank == self.rank_id:
|
||||||
continue
|
continue
|
||||||
comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
|
comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
|
||||||
|
|
||||||
for src_rank in range(self.world_size):
|
for src_rank in range(self.world_size):
|
||||||
if src_rank == self_rank:
|
if src_rank == self.rank_id:
|
||||||
continue
|
continue
|
||||||
comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
|
comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
|
||||||
if comm_op_list:
|
if comm_op_list:
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ _CP_CHUNKEDPREFILL_COMM_STREAM = None
|
|||||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
||||||
_DEFAULT_BUFFER_SIZE = 200
|
_DEFAULT_BUFFER_SIZE = 200
|
||||||
_MIN_DP_BUFFER_SIZE = 50
|
_MIN_DP_BUFFER_SIZE = 50
|
||||||
|
_DYNAMIC_EPLB_BUFFER_SIZE = 1 # num_experts * num_layers * 64 byte
|
||||||
_IS_MOE_MODEL = None
|
_IS_MOE_MODEL = None
|
||||||
_IS_DRAFTER_MOE_MODEL = None
|
_IS_DRAFTER_MOE_MODEL = None
|
||||||
_IS_VL_MODEL = None
|
_IS_VL_MODEL = None
|
||||||
@@ -907,6 +908,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> dict | None:
|
|||||||
return None
|
return None
|
||||||
hccl_config_map = {
|
hccl_config_map = {
|
||||||
"dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
|
"dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
|
||||||
|
"dynamic_eplb": {"hccl_buffer_size": _DYNAMIC_EPLB_BUFFER_SIZE},
|
||||||
}
|
}
|
||||||
return hccl_config_map.get(group_name, get_default_buffer_config())
|
return hccl_config_map.get(group_name, get_default_buffer_config())
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user