[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #6) (#6001)

### What this PR does / why we need it?
Convert the following files under `vllm-ascend/` to ruff format (Batch #6):

| File Path |
| :--- |
| `vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| `vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| `vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| `vllm_ascend/eplb/core/eplb_utils.py` |
| `vllm_ascend/eplb/core/eplb_worker.py` |
| `vllm_ascend/eplb/core/policy/policy_abstract.py` |
| `vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| `vllm_ascend/eplb/core/policy/policy_factory.py` |
| `vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| `vllm_ascend/eplb/core/policy/policy_random.py` |
| `vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| `vllm_ascend/eplb/eplb_updator.py` |
| `vllm_ascend/eplb/utils.py` |
| `vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| `vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| `vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| `vllm_ascend/model_loader/netloader/load.py` |
| `vllm_ascend/model_loader/netloader/netloader.py` |
| `vllm_ascend/model_loader/netloader/utils.py` |
| `vllm_ascend/patch/platform/__init__.py` |
| `vllm_ascend/patch/platform/patch_balance_schedule.py` |
| `vllm_ascend/patch/platform/patch_ec_connector.py` |
| `vllm_ascend/patch/platform/patch_mamba_config.py` |
| `vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| `vllm_ascend/patch/platform/patch_sched_yield.py` |


- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
SILONG ZENG authored on 2026-01-24 22:08:33 +08:00, committed by GitHub
parent 153da1a669
commit 4e53c1d900
26 changed files with 894 additions and 1148 deletions

`pyproject.toml`

@@ -61,10 +61,6 @@ exclude = [
     "vllm_ascend/distributed/kv_transfer/utils/**",
     "vllm_ascend/kv_offload/**",
     "vllm_ascend/lora/**",
-    # (6)
-    "vllm_ascend/eplb/**",
-    "vllm_ascend/model_loader/netloader/**",
-    "vllm_ascend/patch/**",
     # (7)
     "vllm_ascend/quantization/**",
     "vllm_ascend/sample/*.py",
@@ -92,6 +88,7 @@ exclude = [
     "vllm_ascend/distributed/parallel_state.py",
     "vllm_ascend/distributed/utils.py",
     "vllm_ascend/xlite/*.py",
+    "vllm_ascend/patch/worker/patch_*.py",
     # (11)
     "vllm_ascend/ops/fused_moe/**",
 ]

`vllm_ascend/eplb/adaptor/abstract_adaptor.py`

@@ -19,8 +19,7 @@ from abc import abstractmethod
 from typing import Any
-class EplbAdaptor():
+class EplbAdaptor:
     def __init__(self, **args):
         pass
@@ -29,12 +28,9 @@ class EplbAdaptor():
         raise NotImplementedError
     @abstractmethod
-    def do_update_expert_map(self, layer_id: Any,
-                             updated_expert_map: Any) -> Any:
+    def do_update_expert_map(self, layer_id: Any, updated_expert_map: Any) -> Any:
         raise NotImplementedError
     @abstractmethod
-    def do_update_expert_weight(self, layer_id: Any,
-                                local_expert_to_replace: Any,
-                                buffer_tensor_id: Any) -> Any:
+    def do_update_expert_weight(self, layer_id: Any, local_expert_to_replace: Any, buffer_tensor_id: Any) -> Any:
         raise NotImplementedError

`vllm_ascend/eplb/adaptor/vllm_adaptor.py`

@@ -26,7 +26,6 @@ from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
 class VllmEplbAdaptor(EplbAdaptor):
     def __init__(self, model, **args):
         super().__init__(**args)
         self.model = model
@@ -36,33 +35,37 @@ class VllmEplbAdaptor(EplbAdaptor):
self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0) self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
for i in range(self.num_dense_layers, for i in range(self.num_dense_layers, self.model.config.num_hidden_layers):
self.model.config.num_hidden_layers): self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = self.model.model.layers[
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \ i
self.model.model.layers[i].mlp.experts.w13_weight_list ].mlp.experts.w13_weight_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \ self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = self.model.model.layers[
self.model.model.layers[i].mlp.experts.w2_weight_list i
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \ ].mlp.experts.w2_weight_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = (
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ )
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = (
self.model.model.layers[i].mlp.experts.w2_weight_scale_list self.model.model.layers[i].mlp.experts.w2_weight_scale_list
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here )
# TODO: init self.expert_weight_names depending on different model types.
# Only deepseek v3 w8a8 and qwen3-moe is supported here
if self.model.quant_config is not None: if self.model.quant_config is not None:
self.expert_weight_names = [ self.expert_weight_names = [
"w13_weight_list", "w2_weight_list", "w13_weight_list",
"w13_weight_scale_fp32_list", "w13_weight_offset", "w2_weight_list",
"w2_weight_scale_list", "w2_weight_offset" "w13_weight_scale_fp32_list",
"w13_weight_offset",
"w2_weight_scale_list",
"w2_weight_offset",
] ]
else: else:
self.expert_weight_names = ["w13_weight", "w2_weight"] self.expert_weight_names = ["w13_weight", "w2_weight"]
self.expert_map_per_layer_cpu = dict( self.expert_map_per_layer_cpu = dict() # copy of expert map on CPU to avoid device synchronize frequently
) # copy of expert map on CPU to avoid device synchronize frequently
num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
self.buffer_tensor_list: list[list[Any]] = [ self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)]
[] for _ in range(num_buffer_tensor)
]
self.init_buffer_tensor(num_buffer_tensor) self.init_buffer_tensor(num_buffer_tensor)
self.expert_param_per_layer = dict() self.expert_param_per_layer = dict()
@@ -70,18 +73,15 @@ class VllmEplbAdaptor(EplbAdaptor):
self.log2phy_map_per_layer = dict() self.log2phy_map_per_layer = dict()
for layer_idx in range(self.num_moe_layers): for layer_idx in range(self.num_moe_layers):
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \ self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = self.model.get_log2phy_map(
self.model.get_log2phy_map(self.num_dense_layers + layer_idx) self.num_dense_layers + layer_idx
)
def init_buffer_tensor(self, num_buffer_tensor): def init_buffer_tensor(self, num_buffer_tensor):
for buffer_id in range(num_buffer_tensor): for buffer_id in range(num_buffer_tensor):
for name in self.expert_weight_names: for name in self.expert_weight_names:
complete_name = "model.layers." + str( complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
self.num_dense_layers) + ".mlp.experts." + name if name in ["w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w2_weight_scale_list"]:
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
]:
expert_tensor = self.param_dict[complete_name][0] expert_tensor = self.param_dict[complete_name][0]
expert_tensor = expert_tensor.clone() expert_tensor = expert_tensor.clone()
else: else:
@@ -99,19 +99,20 @@ class VllmEplbAdaptor(EplbAdaptor):
per_expert_param = list() per_expert_param = list()
for name in self.expert_weight_names: for name in self.expert_weight_names:
if name in [ if name in [
"w13_weight_list", "w2_weight_list", "w13_weight_list",
"w13_weight_scale_fp32_list", "w2_weight_list",
"w2_weight_scale_list" "w13_weight_scale_fp32_list",
"w2_weight_scale_list",
]: ]:
per_expert_param.append( per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) + self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id]
".mlp.experts." + )
name][local_expert_id])
else: else:
per_expert_param.append( per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) + self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][0].data[
".mlp.experts." + local_expert_id
name][0].data[local_expert_id]) ]
)
self.expert_param_per_layer[layer_idx].append(per_expert_param) self.expert_param_per_layer[layer_idx].append(per_expert_param)
def get_rank_expert_workload(self) -> torch.Tensor: def get_rank_expert_workload(self) -> torch.Tensor:
@@ -123,26 +124,18 @@ class VllmEplbAdaptor(EplbAdaptor):
num_local_experts = expert_maps.max() + 1 num_local_experts = expert_maps.max() + 1
expert_maps_list = expert_maps.tolist() expert_maps_list = expert_maps.tolist()
record: dict[str, Any] = { record: dict[str, Any] = {"moe_layer_count": len(expert_maps_list), "layer_list": []}
"moe_layer_count": len(expert_maps_list),
"layer_list": []
}
for layer_idx, layer_data in enumerate(expert_maps_list): for layer_idx, layer_data in enumerate(expert_maps_list):
layer_record: dict[str, Any] = { layer_record: dict[str, Any] = {
"layer_id": layer_idx, "layer_id": layer_idx,
"device_count": len(layer_data), "device_count": len(layer_data),
"device_list": [] "device_list": [],
} }
for device_idx, experts in enumerate(layer_data): for device_idx, experts in enumerate(layer_data):
placement = [ placement = [experts.index(i) for i in range(num_local_experts)]
experts.index(i) for i in range(num_local_experts) device_record = {"device_id": device_idx, "device_expert": placement}
]
device_record = {
"device_id": device_idx,
"device_expert": placement
}
layer_record["device_list"].append(device_record) layer_record["device_list"].append(device_record)
record["layer_list"].append(layer_record) record["layer_list"].append(layer_record)
@@ -153,11 +146,10 @@ class VllmEplbAdaptor(EplbAdaptor):
def do_update_expert_map(self, layer_id, updated_expert_map): def do_update_expert_map(self, layer_id, updated_expert_map):
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
def do_update_expert_weight(self, layer_id, local_expert_to_replace, def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id):
buffer_tensor_id):
for expert_tensor, buffer_tensor in zip( for expert_tensor, buffer_tensor in zip(
self.expert_param_per_layer[layer_id][local_expert_to_replace], self.expert_param_per_layer[layer_id][local_expert_to_replace], self.buffer_tensor_list[buffer_tensor_id]
self.buffer_tensor_list[buffer_tensor_id]): ):
expert_tensor.copy_(buffer_tensor) expert_tensor.copy_(buffer_tensor)
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}") logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
@@ -168,10 +160,8 @@ class VllmEplbAdaptor(EplbAdaptor):
def get_global_expert_map(self): def get_global_expert_map(self):
all_layer_global_expert_map = [] all_layer_global_expert_map = []
for layer_id in range(self.num_moe_layers): for layer_id in range(self.num_moe_layers):
map_cpu = self.model.model.layers[ map_cpu = self.model.model.layers[self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu()
self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu()
all_layer_global_expert_map.append(map_cpu) all_layer_global_expert_map.append(map_cpu)
self.expert_map_per_layer_cpu[self.num_dense_layers + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_id] = map_cpu[self.rank_id]
layer_id] = map_cpu[self.rank_id]
return torch.stack(all_layer_global_expert_map) return torch.stack(all_layer_global_expert_map)
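As a quick aside on the `param_dict` keys assembled in this constructor, a tiny sketch of the naming scheme (the layer index here is made up for illustration):

```python
# Hypothetical layer index; the real indices run from num_dense_layers upward.
layer_idx = 3
for suffix in ("w13_weight_list", "w2_weight_list"):
    print("model.layers." + str(layer_idx) + ".mlp.experts." + suffix)
# model.layers.3.mlp.experts.w13_weight_list
# model.layers.3.mlp.experts.w2_weight_list
```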

`vllm_ascend/eplb/core/eplb_device_transfer_loader.py`

@@ -27,7 +27,6 @@ class ExpertWeightUpdateState(Enum):
 class D2DExpertWeightLoader:
     def __init__(self):
         self.comm_op_list = None
         self.updated_expert_map = None
@@ -40,14 +39,10 @@ class D2DExpertWeightLoader:
     def set_adator(self, eplb_adaptor):
         self.eplb_adaptor = eplb_adaptor
-    def generate_expert_d2d_transfer_task(self, expert_send_info,
-                                          expert_recv_info, updated_expert_map,
-                                          layer_id):
+    def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, updated_expert_map, layer_id):
         # When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task
         if self.state != ExpertWeightUpdateState.WAITING:
-            logger.warning_once(
-                "current d2d weight update tasks are on-going, cannot accept new weight update task"
-            )
+            logger.warning_once("current d2d weight update tasks are on-going, cannot accept new weight update task")
             return
         self.updated_expert_map = updated_expert_map
@@ -56,25 +51,16 @@ class D2DExpertWeightLoader:
         self.comm_op_list = []
         for send_info in expert_send_info:
             dst_rank, global_expert_id_to_send = send_info
-            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
-                layer_id][global_expert_id_to_send].item()
-            for src_tensor in self.eplb_adaptor.expert_param_per_layer[
-                    layer_id][local_expert_id]:
-                self.comm_op_list.append(
-                    dist.P2POp(dist.isend, src_tensor, dst_rank))
+            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
+            for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]:
+                self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
-        buffer_tensor_id = 0
-        for recv_info in expert_recv_info:
+        for buffer_tensor_id, recv_info in enumerate(expert_recv_info):
             recv_rank, global_expert_id_to_recv = recv_info
-            for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[
-                    buffer_tensor_id]:
-                self.comm_op_list.append(
-                    dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
-            local_expert_to_replace = self.updated_expert_map[
-                global_expert_id_to_recv].item()
-            self.recv_expert_list.append(
-                (local_expert_to_replace, buffer_tensor_id))
-            buffer_tensor_id += 1
+            for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]:
+                self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
+            local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item()
+            self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id))
         self.state = ExpertWeightUpdateState.READY
@@ -106,23 +92,18 @@ class D2DExpertWeightLoader:
         self.comm_op_list = None
         # update expert_map
-        self.eplb_adaptor.do_update_expert_map(self.layer_id,
-                                               self.updated_expert_map)
+        self.eplb_adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map)
         # update log2phy_map
-        self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
-                                                self.updated_log2phy_map)
+        self.eplb_adaptor.do_update_log2phy_map(self.layer_id, self.updated_log2phy_map)
         # update expert weight
         buffer_tensor_id = 0
         for recv_expert_info in self.recv_expert_list:
             local_expert_to_replace, buffer_tensor_id = recv_expert_info
-            self.eplb_adaptor.do_update_expert_weight(self.layer_id,
-                                                      local_expert_to_replace,
-                                                      buffer_tensor_id)
+            self.eplb_adaptor.do_update_expert_weight(self.layer_id, local_expert_to_replace, buffer_tensor_id)
-        logger.debug(
-            f"[EPLB] finished update expert weight for layer: {self.layer_id}")
+        logger.debug(f"[EPLB] finished update expert weight for layer: {self.layer_id}")
         self.recv_expert_list = []
         self.updated_expert_map = None
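For context, `generate_expert_d2d_transfer_task` only collects `dist.P2POp` entries. A minimal sketch (not part of this diff, and assuming an already initialized process group) of how such a `comm_op_list` is typically launched and awaited:

```python
import torch.distributed as dist

def run_batched_p2p(comm_op_list: list[dist.P2POp]) -> None:
    # batch_isend_irecv issues every queued isend/irecv in one call and
    # returns one request handle per op; wait() blocks until completion.
    if not comm_op_list:
        return
    reqs = dist.batch_isend_irecv(comm_op_list)
    for req in reqs:
        req.wait()
```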

`vllm_ascend/eplb/core/eplb_utils.py`

@@ -25,7 +25,7 @@ from vllm.logger import logger
 def expert_file_to_tensor(expert_map_path, layer_id):
-    with open(expert_map_path, "r") as f:
+    with open(expert_map_path) as f:
         data = json.load(f)
     physical_count = 0
     device_data = []
@@ -61,38 +61,32 @@ def init_eplb_config(eplb_config, layer_id, moe_config):
     eplb_enable = eplb_config.dynamic_eplb
     n_redundant = eplb_config.num_redundant_experts if eplb_enable else 0
     if expert_map_path:
-        if not (os.path.exists(expert_map_path)
-                and os.access(expert_map_path, os.R_OK)):
+        if not (os.path.exists(expert_map_path) and os.access(expert_map_path, os.R_OK)):
             raise ValueError("Invalid EPLB path")
         eplb_enable = True
-        global_placement, physical_count = expert_file_to_tensor(
-            expert_map_path, layer_id)
+        global_placement, physical_count = expert_file_to_tensor(expert_map_path, layer_id)
         if physical_count is not None:
             n_redundant = physical_count - n_experts
         if not moe_config.supports_eplb:
-            raise ValueError(
-                "Eplb supports only w8a8_dynamic quantization.")
+            raise ValueError("Eplb supports only w8a8_dynamic quantization.")
     else:
         eplb_enable = False
     if global_placement is None:
-        global_placement = generate_global_placement(n_experts, ep_size,
-                                                     n_redundant)
+        global_placement = generate_global_placement(n_experts, ep_size, n_redundant)
     if ep_size == 1:
         assert not eplb_enable, "EPLB must used in expert parallelism."
         return None, None, None, n_redundant
     global_expert_map = []
     for rankid in range(ep_size):
-        expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
+        expert_map = torch.full((n_experts,), -1, dtype=torch.int32)
         local_placement = global_placement[rankid]
-        expert_map[local_placement] = torch.arange(local_placement.shape[0],
-                                                   dtype=torch.int32)
+        expert_map[local_placement] = torch.arange(local_placement.shape[0], dtype=torch.int32)
         global_expert_map.append(expert_map)
         if rankid == moe_config.ep_rank:
             local_expert_map = expert_map.npu()
-    log2phy = generate_log2phy_map(
-        global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
+    log2phy = generate_log2phy_map(global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
     return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant
@@ -106,13 +100,15 @@ def generate_log2phy_map(global_expert_map, ep_rank):
             if val != -1:
                 log2phy_map[idx].append(val + rankid * valid_count)
-    for key in log2phy_map.keys():
+    for key in log2phy_map:
         num_of_duplications = len(log2phy_map[key])
         log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications]
     log2phy_map = torch.scatter(
-        torch.zeros(len(log2phy_map.keys()), dtype=torch.int32), 0,
-        torch.tensor(list(log2phy_map.keys()), dtype=torch.int64),
-        torch.tensor(list(log2phy_map.values()), dtype=torch.int32))
+        torch.zeros(len(log2phy_map), dtype=torch.int32),
+        0,
+        torch.tensor(list(log2phy_map), dtype=torch.int64),
+        torch.tensor(list(log2phy_map.values()), dtype=torch.int32),
+    )
     return log2phy_map
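To make the reformatted `torch.scatter` call above easier to read, a self-contained toy example (the logical-to-physical ids are made up) of how the per-rank log2phy dict is packed into a tensor:

```python
import torch

# Four logical experts; each maps to one physical expert id on this rank
# (values are illustrative only).
log2phy = {0: 5, 1: 1, 2: 6, 3: 3}
packed = torch.scatter(
    torch.zeros(len(log2phy), dtype=torch.int32),           # output buffer
    0,                                                       # scatter along dim 0
    torch.tensor(list(log2phy), dtype=torch.int64),          # keys -> positions
    torch.tensor(list(log2phy.values()), dtype=torch.int32), # values -> physical ids
)
print(packed)  # tensor([5, 1, 6, 3], dtype=torch.int32)
```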

`vllm_ascend/eplb/core/eplb_worker.py`

@@ -17,23 +17,18 @@
 from multiprocessing import Process, Queue
 from typing import Any
-import networkx as nx  # type: ignore
-import numpy as np
 import torch
 import torch.distributed as dist
 from vllm.logger import logger
 from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map
-from vllm_ascend.eplb.core.policy.policy_factory import (DynamicConfig,
-                                                         PolicyFactory)
+from vllm_ascend.eplb.core.policy.policy_factory import DynamicConfig, PolicyFactory
 class EplbWorker:
     def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
         self.policy_type = policy_type
-        self.policy = PolicyFactory.generate_policy(policy_type,
-                                                    DynamicConfig())
+        self.policy = PolicyFactory.generate_policy(policy_type, DynamicConfig())
         self.shared_dict = shared_dict
         self.old_expert_maps = None
         self.enable_d2d = enable_d2d
@@ -62,10 +57,8 @@ class EplbWorker:
return return
# Get the updated expert table based on the workload information # Get the updated expert table based on the workload information
old_placement = self.global2local(self.old_expert_maps, old_placement = self.global2local(self.old_expert_maps, self.num_local_experts)
self.num_local_experts) _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement)
_, _, new_placement = self.calculate_rebalance_experts(
load_info, old_placement)
if not torch.is_tensor(new_placement): if not torch.is_tensor(new_placement):
new_placement = torch.tensor(new_placement) new_placement = torch.tensor(new_placement)
@@ -73,8 +66,7 @@ class EplbWorker:
new_expert_maps = self.local2global(new_placement) new_expert_maps = self.local2global(new_placement)
self.update_expert_map(new_expert_maps) self.update_expert_map(new_expert_maps)
update_info = self.compose_expert_update_info_greedy( update_info = self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps)
new_expert_maps, self.old_expert_maps)
self.old_expert_maps = new_expert_maps self.old_expert_maps = new_expert_maps
logger.info("EPLB Process compute complete") logger.info("EPLB Process compute complete")
@@ -88,11 +80,8 @@ class EplbWorker:
for layer_id in range(num_layers): for layer_id in range(num_layers):
# check if any logical expert is not placed on any rank # check if any logical expert is not placed on any rank
if torch.unique(new_placement[layer_id]).numel() < torch.unique( if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel():
old_placement[layer_id]).numel(): logger.error(f"There exists expert not placed on any rank in layer {layer_id}")
logger.error(
f"There exists expert not placed on any rank in layer {layer_id}"
)
new_placement[layer_id] = old_placement[layer_id] new_placement[layer_id] = old_placement[layer_id]
continue continue
@@ -101,28 +90,26 @@ class EplbWorker:
old_placement_check = old_placement[layer_id][rank_id] old_placement_check = old_placement[layer_id][rank_id]
# check if same logical experts are placed on the same NPU # check if same logical experts are placed on the same NPU
if new_placement_check.numel() != torch.unique( if new_placement_check.numel() != torch.unique(new_placement_check).numel():
new_placement_check).numel():
logger.error( logger.error(
f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid" "Replicated experts are placed on the same NPU; expert placement on "
f"layer {layer_id}, rank {rank_id} is invalid"
) )
new_placement[layer_id] = old_placement[layer_id] new_placement[layer_id] = old_placement[layer_id]
break break
# check if there is any experts movement inside one NPU # check if there is any experts movement inside one NPU
expert_not_move = torch.isin(new_placement_check, expert_not_move = torch.isin(new_placement_check, old_placement_check)
old_placement_check) if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]):
if not torch.equal(new_placement_check[expert_not_move],
old_placement_check[expert_not_move]):
logger.error( logger.error(
f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid" "There exists expert movement inside NPU; expert placement on "
f"layer {layer_id}, rank {rank_id} is invalid"
) )
new_placement[layer_id] = old_placement[layer_id] new_placement[layer_id] = old_placement[layer_id]
break break
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
def compose_expert_update_info_greedy(self, updated_expert_maps, def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps):
current_expert_maps):
num_layers = current_expert_maps.shape[0] num_layers = current_expert_maps.shape[0]
for layer_id in range(num_layers): for layer_id in range(num_layers):
updated_expert_maps_this_layer = updated_expert_maps[layer_id] updated_expert_maps_this_layer = updated_expert_maps[layer_id]
@@ -132,19 +119,23 @@ class EplbWorker:
expert_recv_info_this_layer: dict[Any, Any] = {} expert_recv_info_this_layer: dict[Any, Any] = {}
# Guard Clause: if there is no expert weight update, avoid subsequent processing # Guard Clause: if there is no expert weight update, avoid subsequent processing
if torch.equal(updated_expert_maps_this_layer, if torch.equal(updated_expert_maps_this_layer, current_expert_maps_this_layer):
current_expert_maps_this_layer): yield (
yield (expert_send_info_this_layer, expert_send_info_this_layer,
expert_recv_info_this_layer, expert_recv_info_this_layer,
updated_expert_maps_this_layer, layer_id) updated_expert_maps_this_layer,
layer_id,
)
# Parse expert_ids each rank needs to receive from other ranks # Parse expert_ids each rank needs to receive from other ranks
dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \ dst_rank_indices, experts_to_recv = torch.where(
& (updated_expert_maps_this_layer != -1)) (current_expert_maps_this_layer == -1) & (updated_expert_maps_this_layer != -1)
)
# Parse expert_ids each rank needs to send to other ranks # Parse expert_ids each rank needs to send to other ranks
src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \ src_rank_indices, experts_to_send = torch.where(
& (updated_expert_maps_this_layer == -1)) (current_expert_maps_this_layer != -1) & (updated_expert_maps_this_layer == -1)
)
for idx in range(len(dst_rank_indices)): for idx in range(len(dst_rank_indices)):
dst_rank_id = dst_rank_indices[idx].item() dst_rank_id = dst_rank_indices[idx].item()
@@ -152,27 +143,27 @@ class EplbWorker:
if dst_rank_id not in expert_recv_info_this_layer: if dst_rank_id not in expert_recv_info_this_layer:
expert_recv_info_this_layer[dst_rank_id] = [] expert_recv_info_this_layer[dst_rank_id] = []
if not torch.isin(torch.tensor(expert_id), if not torch.isin(torch.tensor(expert_id), experts_to_send).any():
experts_to_send).any():
# if expert_id are not sent out from any npu, it will be copied from one npu holding this expert # if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
candidate_src_rank_indices = torch.where( candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0]
current_expert_maps_this_layer[:, expert_id] != -1)[0]
else: else:
candidate_src_rank_indices = src_rank_indices[ candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id]
experts_to_send == expert_id]
# TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node... # TODO: improve selection criterion of NPU sending expert_id,
# considering intra-node or inter-node...
src_rank_id = candidate_src_rank_indices[0].item() src_rank_id = candidate_src_rank_indices[0].item()
if src_rank_id not in expert_send_info_this_layer: if src_rank_id not in expert_send_info_this_layer:
expert_send_info_this_layer[src_rank_id] = [] expert_send_info_this_layer[src_rank_id] = []
expert_send_info_this_layer[src_rank_id].append( expert_send_info_this_layer[src_rank_id].append((dst_rank_id, expert_id))
(dst_rank_id, expert_id)) expert_recv_info_this_layer[dst_rank_id].append((src_rank_id, expert_id))
expert_recv_info_this_layer[dst_rank_id].append(
(src_rank_id, expert_id))
yield (expert_send_info_this_layer, expert_recv_info_this_layer, yield (
updated_expert_maps_this_layer, layer_id) expert_send_info_this_layer,
expert_recv_info_this_layer,
updated_expert_maps_this_layer,
layer_id,
)
def calculate_rebalance_experts(self, load_info, old_placement): def calculate_rebalance_experts(self, load_info, old_placement):
""" """
@@ -181,8 +172,7 @@ class EplbWorker:
if self.old_expert_maps is None: if self.old_expert_maps is None:
return False, None, None return False, None, None
changed, priority, new_map = self.policy.rebalance_experts( changed, priority, new_map = self.policy.rebalance_experts(old_placement, load_info)
old_placement, load_info)
return changed, priority, new_map return changed, priority, new_map
def get_init_expert_maps(self): def get_init_expert_maps(self):
@@ -199,19 +189,13 @@ class EplbWorker:
return self.shared_dict.get("moe_load", None) return self.shared_dict.get("moe_load", None)
def update_expert_map(self, expert_maps): def update_expert_map(self, expert_maps):
self.shared_dict["expert_maps"] = expert_maps self.shared_dict["expert_maps"] = expert_maps
def global2local(self, placement: torch.Tensor, def global2local(self, placement: torch.Tensor, E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
L, G, _ = placement.shape L, G, _ = placement.shape
device = placement.device device = placement.device
pt_local = torch.full((L, G, E_local), pt_local = torch.full((L, G, E_local), fill_value=-1, dtype=torch.long, device=device)
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement >= 0 valid = placement >= 0
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
@@ -223,7 +207,6 @@ class EplbWorker:
return pt_local return pt_local
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor: def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
L, G, E_local = placement_local.shape L, G, E_local = placement_local.shape
device = placement_local.device device = placement_local.device
@@ -233,10 +216,7 @@ class EplbWorker:
if E_global == 0: if E_global == 0:
return torch.empty((L, G, 0), dtype=torch.long, device=device) return torch.empty((L, G, 0), dtype=torch.long, device=device)
placement_global = torch.full((L, G, E_global), placement_global = torch.full((L, G, E_global), fill_value=-1, dtype=torch.long, device=device)
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement_local >= 0 valid = placement_local >= 0
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
@@ -257,11 +237,8 @@ class EplbWorker:
layer_ids = [] layer_ids = []
for send_info, recv_info, new_expert_map, layer_id in update_info_generator: for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
send_info_this_rank = send_info.get(self.rank_id, [])
send_info_this_rank = send_info[ recv_info_this_rank = recv_info.get(self.rank_id, [])
self.rank_id] if self.rank_id in send_info else []
recv_info_this_rank = recv_info[
self.rank_id] if self.rank_id in recv_info else []
send_all.append(send_info_this_rank) send_all.append(send_info_this_rank)
recv_all.append(recv_info_this_rank) recv_all.append(recv_info_this_rank)
@@ -276,11 +253,7 @@ class EplbWorker:
class EplbProcess: class EplbProcess:
def __init__(self, shared_dict, policy_type: int = 0, enable_d2d: bool = True):
def __init__(self,
shared_dict,
policy_type: int = 0,
enable_d2d: bool = True):
""" """
Args: Args:
shared_dict: Cross-process shared dict returned by Manager().dict() shared_dict: Cross-process shared dict returned by Manager().dict()
@@ -294,12 +267,12 @@ class EplbProcess:
self.block_update_q: Queue[Any] = Queue(maxsize=1) self.block_update_q: Queue[Any] = Queue(maxsize=1)
# Create EplbWorker instance # Create EplbWorker instance
self.worker = EplbWorker(self.shared_dict, self.policy_type, self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d)
self.enable_d2d)
def worker_process(self, planner_q, block_update_q): def worker_process(self, planner_q, block_update_q):
""" """
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up,
call do_update, then notify main process update is complete.
""" """
while True: while True:
try: try:
@@ -314,17 +287,17 @@ class EplbProcess:
break break
except Exception as e: except Exception as e:
logger.warning(f"[EPLB subprocess Exiting due to error: {e}", logger.warning(
exc_info=True) f"[EPLB subprocess exiting due to error: {e}]",
exc_info=True,
)
break break
def _launch_process(self): def _launch_process(self):
""" """
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
""" """
proc = Process(target=self.worker_process, proc = Process(target=self.worker_process, args=(self.planner_q, self.block_update_q), daemon=True)
args=(self.planner_q, self.block_update_q),
daemon=True)
proc.start() proc.start()
return proc return proc
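The `global2local`/`local2global` helpers reformatted in this file invert expert placement tensors. A standalone sketch of the forward direction (the tail of the function is filled in here from its evident intent, so treat it as illustrative rather than the exact implementation):

```python
import torch

def global2local(placement: torch.Tensor, e_local: int) -> torch.Tensor:
    # placement[L, G, E_global]: local slot of each global expert per layer/rank, -1 if absent.
    # Returns pt_local[L, G, E_local]: which global expert occupies each local slot.
    L, G, _ = placement.shape
    pt_local = torch.full((L, G, e_local), fill_value=-1, dtype=torch.long, device=placement.device)
    valid = placement >= 0
    l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
    slot_idx = placement[l_idx, g_idx, k_idx]
    pt_local[l_idx, g_idx, slot_idx] = k_idx
    return pt_local

# One layer, two ranks, four global experts, two local slots per rank.
placement = torch.tensor([[[0, 1, -1, -1], [-1, -1, 1, 0]]])
print(global2local(placement, 2))  # tensor([[[0, 1], [3, 2]]])
```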

`vllm_ascend/eplb/core/policy/policy_abstract.py`

@@ -12,7 +12,6 @@ class DynamicConfig:
 class EplbPolicy:
     def __init__(self, config: DynamicConfig):
         self.config = config

`vllm_ascend/eplb/core/policy/policy_default_eplb.py`

@@ -25,13 +25,11 @@ class DynamicTable:
class DefaultEplb(EplbPolicy): class DefaultEplb(EplbPolicy):
def __init__(self, config: DynamicConfig): def __init__(self, config: DynamicConfig):
super().__init__(config) super().__init__(config)
@staticmethod @staticmethod
def add_redundant(current_expert_table, expert_workload, def add_redundant(current_expert_table, expert_workload, num_original_expert):
num_original_expert):
layer_num, npu_num, experts_per_npu = expert_workload.shape layer_num, npu_num, experts_per_npu = expert_workload.shape
workload_new = np.zeros((layer_num, num_original_expert)) workload_new = np.zeros((layer_num, num_original_expert))
for layer_idx in range(layer_num): for layer_idx in range(layer_num):
@@ -40,31 +38,24 @@ class DefaultEplb(EplbPolicy):
workload_layer = expert_workload[layer_idx].copy() workload_layer = expert_workload[layer_idx].copy()
for npu_idx in range(npu_num): for npu_idx in range(npu_num):
for expert_idx in range(experts_per_npu): for expert_idx in range(experts_per_npu):
workload_dict[placement_layer[npu_idx][ workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx]
expert_idx]] += workload_layer[npu_idx][expert_idx]
for expert_idx in range(num_original_expert): for expert_idx in range(num_original_expert):
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
return workload_new return workload_new
@staticmethod @staticmethod
# Split hot (high-load) experts into redundant experts # Split hot (high-load) experts into redundant experts
def original_compute_balanced_pack_redundancy(origin_weights, card_num, def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert):
num_redundancy_expert):
# Step 1: Sort the items by weight in descending order (we are sorting by weight now) # Step 1: Sort the items by weight in descending order (we are sorting by weight now)
# Sort based on the second element (the second value of each tuple) # Sort based on the second element (the second value of each tuple)
route_expert_num = len(origin_weights) route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [ route_expert_redundancy: list[list[int]] = [[] for _ in range(route_expert_num)]
[] for _ in range(route_expert_num)
]
for i in range(num_redundancy_expert): for i in range(num_redundancy_expert):
sorted_indices = np.argsort([t[1] for t in origin_weights], sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1]
kind='stable')[::-1]
weights = [origin_weights[idx] for idx in sorted_indices] weights = [origin_weights[idx] for idx in sorted_indices]
tmp_raw_weight = weights[0][1] * ( tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1)
len(route_expert_redundancy[weights[0][0]]) + 1)
route_expert_redundancy[weights[0][0]].append(route_expert_num + i) route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
avg_weight = tmp_raw_weight / ( avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1)
len(route_expert_redundancy[weights[0][0]]) + 1)
weights[0] = (weights[0][0], avg_weight) weights[0] = (weights[0][0], avg_weight)
origin_weights = weights origin_weights = weights
@@ -93,8 +84,7 @@ class DefaultEplb(EplbPolicy):
box_counts[index] += 1 box_counts[index] += 1
index += 1 index += 1
sorted_indices = np.argsort([t[1] for t in origin_weights], sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1]
kind='stable')[::-1]
origin_weights = [origin_weights[idx] for idx in sorted_indices] origin_weights = [origin_weights[idx] for idx in sorted_indices]
# Step 4: Distribute items into boxes based on weight # Step 4: Distribute items into boxes based on weight
for item_id, weight in origin_weights: for item_id, weight in origin_weights:
@@ -104,11 +94,8 @@ class DefaultEplb(EplbPolicy):
if item_id in boxes[i]: if item_id in boxes[i]:
continue continue
# Only choose boxes that still have space (box_counts[i] < items_per_box) # Only choose boxes that still have space (box_counts[i] < items_per_box)
if box_counts[i] < items_per_box or (box_counts[i] if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0):
== items_per_box if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]:
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
min_box_index = i min_box_index = i
# Place the item (id) into the selected box # Place the item (id) into the selected box
@@ -118,40 +105,35 @@ class DefaultEplb(EplbPolicy):
box_counts[min_box_index] += 1 box_counts[min_box_index] += 1
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter # If there's an imbalance in the remaining items, reduce the "remaining_items" counter
if box_counts[min_box_index] == (items_per_box + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0:
1) and remaining_items > 0:
remaining_items -= 1 remaining_items -= 1
# Step 5: Output each box's contents and total weight # Step 5: Output each box's contents and total weight
result = [] result = []
for i in range(card_num): for i in range(card_num):
result.append({ result.append(
"box_index": i + 1, {
"items": boxes[i], # List of item IDs in the box "box_index": i + 1,
"weight": boxes_weights[i], "items": boxes[i], # List of item IDs in the box
"total_weight": box_weights[i], # Total weight in this box "weight": boxes_weights[i],
"item_count": box_counts[i] # Number of items in the box "total_weight": box_weights[i], # Total weight in this box
}) "item_count": box_counts[i], # Number of items in the box
}
)
return result, boxes return result, boxes
# Split hot (high-load) experts into redundant experts # Split hot (high-load) experts into redundant experts
@staticmethod @staticmethod
def compute_balanced_pack_redundancy(origin_weights, card_num, def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert):
num_redundancy_expert):
route_expert_num = len(origin_weights) route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [ route_expert_redundancy: list[list[int]] = [[] for _ in range(route_expert_num)]
[] for _ in range(route_expert_num)
]
for i in range(num_redundancy_expert): for i in range(num_redundancy_expert):
sorted_indices = np.argsort([t[1] for t in origin_weights], sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1]
kind='stable')[::-1]
weights = [origin_weights[idx] for idx in sorted_indices] weights = [origin_weights[idx] for idx in sorted_indices]
tmp_raw_weight = weights[0][1] * ( tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1)
len(route_expert_redundancy[weights[0][0]]) + 1)
route_expert_redundancy[weights[0][0]].append(route_expert_num + i) route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
avg_weight = tmp_raw_weight / ( avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1)
len(route_expert_redundancy[weights[0][0]]) + 1)
weights[0] = (weights[0][0], avg_weight) weights[0] = (weights[0][0], avg_weight)
origin_weights = weights origin_weights = weights
@@ -166,7 +148,7 @@ class DefaultEplb(EplbPolicy):
box_weights = [0] * card_num box_weights = [0] * card_num
box_counts = [0] * card_num box_counts = [0] * card_num
all_weights = np.zeros((expert_num, ), dtype='object') all_weights = np.zeros((expert_num,), dtype="object")
all_weights[:route_expert_num] = origin_weights all_weights[:route_expert_num] = origin_weights
index = route_expert_num index = route_expert_num
@@ -178,17 +160,13 @@ class DefaultEplb(EplbPolicy):
all_weights[index] = (item, weight) all_weights[index] = (item, weight)
index += 1 index += 1
sorted_indices = np.argsort([t[1] for t in all_weights], sorted_indices = np.argsort([t[1] for t in all_weights], kind="stable")[::-1]
kind='stable')[::-1]
all_weights = [all_weights[idx] for idx in sorted_indices] all_weights = [all_weights[idx] for idx in sorted_indices]
for item_id, weight in all_weights: for item_id, weight in all_weights:
min_box_index = -1 min_box_index = -1
for i in range(card_num): for i in range(card_num):
if box_counts[i] < items_per_box or (box_counts[i] if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0):
== items_per_box if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]:
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
if item_id not in boxes[i]: if item_id not in boxes[i]:
min_box_index = i min_box_index = i
@@ -197,19 +175,20 @@ class DefaultEplb(EplbPolicy):
box_weights[min_box_index] += weight box_weights[min_box_index] += weight
box_counts[min_box_index] += 1 box_counts[min_box_index] += 1
if box_counts[min_box_index] == (items_per_box + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0:
1) and remaining_items > 0:
remaining_items -= 1 remaining_items -= 1
result = [] result = []
for i in range(card_num): for i in range(card_num):
result.append({ result.append(
"box_index": i + 1, {
"items": boxes[i], "box_index": i + 1,
"weight": boxes_weights[i], "items": boxes[i],
"total_weight": box_weights[i], "weight": boxes_weights[i],
"item_count": box_counts[i] "total_weight": box_weights[i],
}) "item_count": box_counts[i],
}
)
return result, boxes return result, boxes
@@ -232,11 +211,8 @@ class DefaultEplb(EplbPolicy):
for item_id, weight in weights: for item_id, weight in weights:
min_box_index = -1 min_box_index = -1
for i in range(card_num): for i in range(card_num):
if box_counts[i] < items_per_box or (box_counts[i] if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0):
== items_per_box if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]:
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
min_box_index = i min_box_index = i
boxes[min_box_index].append(item_id) boxes[min_box_index].append(item_id)
@@ -244,19 +220,20 @@ class DefaultEplb(EplbPolicy):
box_weights[min_box_index] += weight box_weights[min_box_index] += weight
box_counts[min_box_index] += 1 box_counts[min_box_index] += 1
if box_counts[min_box_index] == (items_per_box + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0:
1) and remaining_items > 0:
remaining_items -= 1 remaining_items -= 1
result = [] result = []
for i in range(card_num): for i in range(card_num):
result.append({ result.append(
"box_index": i + 1, {
"items": boxes[i], "box_index": i + 1,
"weight": boxes_weights[i], "items": boxes[i],
"total_weight": box_weights[i], "weight": boxes_weights[i],
"item_count": box_counts[i] "total_weight": box_weights[i],
}) "item_count": box_counts[i],
}
)
return result, boxes return result, boxes
@@ -274,16 +251,11 @@ class DefaultEplb(EplbPolicy):
return max_heat_per_layer return max_heat_per_layer
@staticmethod @staticmethod
def constraint_expert_local_exchange(current_expert_table, def constraint_expert_local_exchange(current_expert_table, global_deployment):
global_deployment):
for layer_id in range(len(global_deployment)): for layer_id in range(len(global_deployment)):
for card_id in range(len(global_deployment[layer_id])): for card_id in range(len(global_deployment[layer_id])):
current_list = [ current_list = [int(x) for x in current_expert_table[layer_id][card_id]]
int(x) for x in current_expert_table[layer_id][card_id] new_list = [int(x) for x in global_deployment[layer_id][card_id]]
]
new_list = [
int(x) for x in global_deployment[layer_id][card_id]
]
num = len(new_list) num = len(new_list)
new_index = [-1] * num new_index = [-1] * num
@@ -293,8 +265,7 @@ class DefaultEplb(EplbPolicy):
for i in range(num): for i in range(num):
flag = True flag = True
for j in range(num): for j in range(num):
if new_list[i] == current_list[j] and new_index[ if new_list[i] == current_list[j] and new_index[j] == -1:
j] == -1:
new_index[j] = 0 new_index[j] = 0
new_result[j] = current_list[j] new_result[j] = current_list[j]
flag = False flag = False
@@ -313,7 +284,6 @@ class DefaultEplb(EplbPolicy):
return global_deployment return global_deployment
def rebalance_experts(self, current_expert_table, expert_workload): def rebalance_experts(self, current_expert_table, expert_workload):
info = DynamicTable() info = DynamicTable()
info.workload_table = np.array(expert_workload) info.workload_table = np.array(expert_workload)
info.placement_table = np.array(current_expert_table) info.placement_table = np.array(current_expert_table)
@@ -324,17 +294,15 @@ class DefaultEplb(EplbPolicy):
expert_ids, counts = np.unique(row, return_counts=True) expert_ids, counts = np.unique(row, return_counts=True)
num_redundancy_expert = self.get_redundant_num(num_npus, counts) num_redundancy_expert = self.get_redundant_num(num_npus, counts)
num_original_expert = len(expert_ids) num_original_expert = len(expert_ids)
layer_workloads = self.add_redundant(info.placement_table, layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert)
info.workload_table, max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num)
num_original_expert)
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
info.workload_table, layer_num)
npu_heat_all_origin = sum(max_heat_per_layer_before) npu_heat_all_origin = sum(max_heat_per_layer_before)
# Perform load balancing and deploy redundant experts # Perform load balancing and deploy redundant experts
layer_num = layer_workloads.shape[0] layer_num = layer_workloads.shape[0]
expert_num = layer_workloads.shape[1] expert_num = layer_workloads.shape[1]
# Validate that the number of experts, number of cards, and number of redundant experts do not exceed the number of cards # Validate that the number of experts, number of cards, and number of redundant experts
# do not exceed the number of cards.
if num_original_expert != expert_num: if num_original_expert != expert_num:
raise ValueError( raise ValueError(
f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}" f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
@@ -345,38 +313,35 @@ class DefaultEplb(EplbPolicy):
if num_npus < num_redundancy_expert: if num_npus < num_redundancy_expert:
raise ValueError( raise ValueError(
f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}" "the number of NPUs "
f"{num_npus} must be greater than or equal to the number of redundant experts "
f"{num_redundancy_expert}"
) )
# Number of experts deployed on each card includes one redundant expert # Number of experts deployed on each card includes one redundant expert
global_deployment: list[list[list[int]]] = [[[] global_deployment: list[list[list[int]]] = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
for _ in range(num_npus)]
for _ in range(layer_num)]
# Iterate to obtain the placement strategy for each layer, taking computational balance into account # Iterate to obtain the placement strategy for each layer, taking computational balance into account
max_heat_per_layer_after = np.zeros([layer_num]) max_heat_per_layer_after = np.zeros([layer_num])
for layer in range(layer_num): for layer in range(layer_num):
# Get the expert IDs and their corresponding workloads for the current layer; # Get the expert IDs and their corresponding workloads for the current layer;
# workloads need to be normalized, and one redundant expert is added per card # workloads need to be normalized, and one redundant expert is added per card
weights = np.zeros((expert_num, ), dtype='object') weights = np.zeros((expert_num,), dtype="object")
for expert_id, workload_weight in enumerate( for expert_id, workload_weight in enumerate(layer_workloads[layer]):
layer_workloads[layer]):
weights[expert_id] = (expert_id, workload_weight) weights[expert_id] = (expert_id, workload_weight)
# Obtain the globally balanced placement strategy for each layer # Obtain the globally balanced placement strategy for each layer
result, layer_deployment = self.original_compute_balanced_pack_redundancy( result, layer_deployment = self.original_compute_balanced_pack_redundancy(
weights, num_npus, num_redundancy_expert) weights, num_npus, num_redundancy_expert
)
global_deployment[layer] = layer_deployment global_deployment[layer] = layer_deployment
max_heat_per_layer_after[layer] = max( max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_weight"])["total_weight"]
result, key=lambda x: x['total_weight'])['total_weight']
new_global_deployment = self.constraint_expert_local_exchange( new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
current_expert_table, global_deployment)
# Obtain the priority of each layer # Obtain the priority of each layer
layer_changed_ratio = [] layer_changed_ratio = []
for layer_idx in range(layer_num): for layer_idx in range(layer_num):
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
max_heat_per_layer_before[layer_idx])
per_layer_priority = np.argsort(layer_changed_ratio) per_layer_priority = np.argsort(layer_changed_ratio)
npu_heat_all_after = sum(max_heat_per_layer_after) npu_heat_all_after = sum(max_heat_per_layer_after)
@@ -385,5 +350,4 @@ class DefaultEplb(EplbPolicy):
if npu_heat_all_after < 0.95 * npu_heat_all_origin: if npu_heat_all_after < 0.95 * npu_heat_all_origin:
change = 1 change = 1
return change, per_layer_priority, np.array( return change, per_layer_priority, np.array(new_global_deployment).tolist()
new_global_deployment).tolist()
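All of the packing loops reformatted above share one greedy rule: place the next-heaviest expert on the least-loaded card that still has a free slot. A simplified, self-contained sketch of that rule (it ignores the redundancy handling and the "card already holds this expert" checks of the real policy):

```python
import numpy as np

def greedy_pack(weights: list[tuple[int, float]], card_num: int, items_per_box: int):
    boxes: list[list[int]] = [[] for _ in range(card_num)]
    box_weights = [0.0] * card_num
    # Heaviest experts first, stable sort like the policy above.
    order = np.argsort([w for _, w in weights], kind="stable")[::-1]
    for item_id, weight in (weights[i] for i in order):
        free = [i for i in range(card_num) if len(boxes[i]) < items_per_box]
        target = min(free, key=lambda i: box_weights[i])  # least-loaded card with space
        boxes[target].append(item_id)
        box_weights[target] += weight
    return boxes, box_weights

print(greedy_pack([(0, 9.0), (1, 5.0), (2, 4.0), (3, 1.0)], card_num=2, items_per_box=2))
# ([[0, 3], [1, 2]], [10.0, 9.0])
```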

`vllm_ascend/eplb/core/policy/policy_factory.py`

@@ -2,29 +2,26 @@
 # Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
 from .policy_abstract import DynamicConfig, EplbPolicy
 from .policy_default_eplb import DefaultEplb
-from .policy_swift_balancer import SwiftBalanceEplb
 from .policy_flashlb import FlashLB, warm_up
 from .policy_random import RandomLoadBalance
+from .policy_swift_balancer import SwiftBalanceEplb
 class PolicyFactory:
     @staticmethod
     def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
         policy = {
             # Constraint applying Dynamic EPLB policy V2:
             # If there exists redundant expert:
             #   only one redundant expert can be placed in one NPU and its physical expert index must be 0
             # Applying greedy d2d expert weight update composing
-            0:
-            RandomLoadBalance,  # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
-            1:
-            DefaultEplb,  # Dynamic EPLB policy: overall expert replacement based on current moe load
-            2:
-            SwiftBalanceEplb,  # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
-            3:
-            FlashLB,  # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
+            0: RandomLoadBalance,  # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
+            1: DefaultEplb,  # Dynamic EPLB policy: overall expert replacement based on current moe load
+            # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
+            2: SwiftBalanceEplb,
+            # FlashLB EPLB policy: expert replacement based on Joint Optimization,
+            # Multi-Shot Enhancement and Incremental Adjustment
+            3: FlashLB,
         }
         policy_class = policy.get(policy_type, RandomLoadBalance)
         policy_instance = policy_class(config)
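A hedged usage sketch of the factory after this change (the shapes below are invented for illustration and this is not an executed test):

```python
import numpy as np

from vllm_ascend.eplb.core.policy.policy_factory import DynamicConfig, PolicyFactory

# 2 MoE layers, 4 NPUs, 2 physical experts per NPU; logical ids 0-7, no redundancy.
current_expert_table = np.tile(np.arange(8).reshape(1, 4, 2), (2, 1, 1))
expert_workload = np.random.randint(1, 100, size=(2, 4, 2))

policy = PolicyFactory.generate_policy(policy_type=1, config=DynamicConfig())  # 1 -> DefaultEplb
changed, priority, new_placement = policy.rebalance_experts(current_expert_table, expert_workload)
```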

`vllm_ascend/eplb/core/policy/policy_flashlb.py`

@@ -3,7 +3,6 @@
import logging import logging
from collections import deque from collections import deque
from typing import Dict
import numpy as np import numpy as np
import torch import torch
@@ -45,8 +44,7 @@ def compute_piece_counts(X, P, stage_weights):
secv = unit[i, idx2] secv = unit[i, idx2]
alt = X[i, idx1] / (pieces[idx1] + 1) alt = X[i, idx1] / (pieces[idx1] + 1)
delta = origin - (alt if alt > secv else secv) delta = origin - (alt if alt > secv else secv)
deltas[idx1] += delta * stage_weights[i] if np.any( deltas[idx1] += delta * stage_weights[i] if np.any(delta) != 0 else stage_weights[i]
delta) != 0 else stage_weights[i]
max_idx = np.argmax(deltas) max_idx = np.argmax(deltas)
pieces[max_idx] += 1 pieces[max_idx] += 1
@@ -157,9 +155,7 @@ def jsq_placement(X, pieces, M, stage_weights):
# Get elements already in current column # Get elements already in current column
current_rank_elements = set(deployment[rank, :]) current_rank_elements = set(deployment[rank, :])
# Filter elements from range(N) not in current column # Filter elements from range(N) not in current column
available = [ available = [x for x in range(N) if x not in current_rank_elements]
x for x in range(N) if x not in current_rank_elements
]
# Randomly select an available element to fill # Randomly select an available element to fill
if len(available) > 0: if len(available) > 0:
rand_idx = np.random.randint(0, len(available)) rand_idx = np.random.randint(0, len(available))
@@ -187,8 +183,7 @@ def slice_values(X, pieces):
@njit @njit
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, simulated_deployment, stage_weights):
n_stage, N = X.shape n_stage, N = X.shape
num_group = P // M num_group = P // M
@@ -207,12 +202,11 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
flat_deployment = simulated_deployment.reshape(-1) flat_deployment = simulated_deployment.reshape(-1)
simulated_load = np.zeros(M, dtype=np.float32) simulated_load = np.zeros(M, dtype=np.float32)
for i in range(flat_deployment.shape[0]): for i in range(flat_deployment.shape[0]):
simulated_load[i // (flat_deployment.shape[0] // simulated_load[i // (flat_deployment.shape[0] // M)] += unit_load[flat_deployment[i]]
M)] += unit_load[flat_deployment[i]]
slice_vals = slice_values(X_all, simulated_pieces) slice_vals = slice_values(X_all, simulated_pieces)
sorted_slices = np.sort(slice_vals)[::-1] sorted_slices = np.sort(slice_vals)[::-1]
simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M simulated_slopes = (sorted_slices[: -M + 1] - sorted_slices[M - 1 :]) / M
cumulative_slices_used = np.zeros(N, dtype=np.int32) cumulative_slices_used = np.zeros(N, dtype=np.int32)
acc = 0 acc = 0
@@ -230,8 +224,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
slices_used_per_group = np.zeros(num_group, dtype=np.int32) slices_used_per_group = np.zeros(num_group, dtype=np.int32)
slices_used_per_group[0] = group_boundary_indices[0] slices_used_per_group[0] = group_boundary_indices[0]
for i in range(1, num_group): for i in range(1, num_group):
slices_used_per_group[ slices_used_per_group[i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
slices_used_per_group = M - slices_used_per_group slices_used_per_group = M - slices_used_per_group
loads = np.zeros(M, dtype=np.float32) loads = np.zeros(M, dtype=np.float32)
@@ -240,7 +233,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
current_idx = 0 current_idx = 0
for g in range(num_group): for g in range(num_group):
window = X_sorted[:, current_idx:current_idx + 2 * M] window = X_sorted[:, current_idx : current_idx + 2 * M]
low = max(0, current_idx + M - N) low = max(0, current_idx + M - N)
high = min(num_remain_slice, M - 1) high = min(num_remain_slice, M - 1)
@@ -248,8 +241,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
mid = int((high + low) // 2) mid = int((high + low) // 2)
keep = M - mid keep = M - mid
current_group = window[:, :keep] current_group = window[:, :keep]
current_pieces = compute_piece_counts(current_group, M, current_pieces = compute_piece_counts(current_group, M, stage_weights)
stage_weights)
current_pieces = np.maximum(current_pieces, 1) current_pieces = np.maximum(current_pieces, 1)
current_slice = slice_values(current_group.sum(0), current_pieces) current_slice = slice_values(current_group.sum(0), current_pieces)
current_slice_sorted = np.sort(current_slice) current_slice_sorted = np.sort(current_slice)
@@ -257,8 +249,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
current_max: np.float32 = np.max(current_loads) current_max: np.float32 = np.max(current_loads)
current_min: np.float32 = np.min(current_loads) current_min: np.float32 = np.min(current_loads)
current_slope = (current_max - current_min) / M current_slope = (current_max - current_min) / M
next_slope: np.float32 = np.max(simulated_slopes[current_idx + next_slope: np.float32 = np.max(simulated_slopes[current_idx + keep :])
keep:])
if abs(current_slope) > abs(next_slope): if abs(current_slope) > abs(next_slope):
low = mid low = mid
@@ -327,9 +318,7 @@ def auto_fix_new_placement(old_placement, new_placement):
old_row = old_placement[rank_id] old_row = old_placement[rank_id]
new_row = new_placement[rank_id] new_row = new_placement[rank_id]
index_array = np.full((max_expert + 1, num_experts), index_array = np.full((max_expert + 1, num_experts), -1, dtype=np.int32)
-1,
dtype=np.int32)
count_array = np.zeros(max_expert + 1, dtype=np.int32) count_array = np.zeros(max_expert + 1, dtype=np.int32)
for idx in range(num_experts): for idx in range(num_experts):
@@ -387,21 +376,17 @@ def auto_fix_new_placement(old_placement, new_placement):
class FlashLB(EplbPolicy):
def __init__(self, config: DynamicConfig):
super().__init__(config)
self.par_history: dict[int, float] = {}
self.hotness_window: dict[int, deque[float]] = {}
self.max_stage_window = config.max_stage_window if hasattr(config, "max_stage_window") else 1
self.buffer_expert_layer_num = (
config.buffer_expert_layer_num if hasattr(config, "buffer_expert_layer_num") else 58
)
self.threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else 0
def compute_expert_hotness(self, num_of_expert: int, deployment: np.ndarray, rank_load: np.ndarray):
hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
deployment_flat = deployment.ravel() deployment_flat = deployment.ravel()
rank_load_flat = rank_load.ravel() rank_load_flat = rank_load.ravel()
@@ -413,22 +398,14 @@ class FlashLB(EplbPolicy):
if np.any(deployment < 0): if np.any(deployment < 0):
raise ValueError("Deployment table contains negative values.") raise ValueError("Deployment table contains negative values.")
counts = np.bincount(deployment.reshape(-1), minlength=N) counts = np.bincount(deployment.reshape(-1), minlength=N)
unit_hotness = np.divide(hotness, counts, out=np.zeros_like(hotness, dtype=float), where=counts != 0)
stage_par = np.zeros(n_stage) stage_par = np.zeros(n_stage)
for i in range(n_stage): for i in range(n_stage):
stage_load = unit_hotness[i][deployment].sum(-1) stage_load = unit_hotness[i][deployment].sum(-1)
stage_par[i] = stage_load.max() / stage_load.mean() stage_par[i] = stage_load.max() / stage_load.mean()
return stage_par.mean() return stage_par.mean()
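As a side note, the loop above boils each stage down to a peak-to-average ratio (PAR) of per-rank load; a tiny standalone numpy sketch of that reduction, with made-up numbers:

# Illustrative only: PAR of per-NPU load for one stage, as computed in the loop above.
import numpy as np

stage_load = np.array([120.0, 80.0, 100.0, 100.0])  # hypothetical load on 4 NPUs
par = stage_load.max() / stage_load.mean()          # 1.2; a perfectly balanced stage gives 1.0
print(par)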
def group_based_adaptive_bloating(self, X, P, M, stage_weights=None, recorsive=False):
n_stage, N = X.shape n_stage, N = X.shape
if stage_weights is None: if stage_weights is None:
stage_weights = np.ones(n_stage, dtype=np.float32) stage_weights = np.ones(n_stage, dtype=np.float32)
@@ -437,15 +414,10 @@ class FlashLB(EplbPolicy):
(
simulated_deployment,
simulated_pieces,
) = self.group_based_adaptive_bloating(X, P, M, stage_weights, recorsive=False)
else:
simulated_pieces = compute_piece_counts(X, P, stage_weights)
simulated_deployment = jsq_placement(X, simulated_pieces, M, stage_weights)
pieces = group_based_adaptive_bloating_kernel( pieces = group_based_adaptive_bloating_kernel(
X.astype(np.float32), X.astype(np.float32),
@@ -459,10 +431,7 @@ class FlashLB(EplbPolicy):
deployment = jsq_placement(X, pieces, M, stage_weights) deployment = jsq_placement(X, pieces, M, stage_weights)
X_all = X.sum(0) X_all = X.sum(0)
unit_load = np.divide(X_all, pieces, out=np.zeros_like(X_all, dtype=float), where=pieces != 0)
load = unit_load[deployment].sum(-1) load = unit_load[deployment].sum(-1)
sim_unit_load = X_all / simulated_pieces sim_unit_load = X_all / simulated_pieces
@@ -510,16 +479,13 @@ class FlashLB(EplbPolicy):
def register_hotness(self, deployment, rank_load, num_layer, num_expert): def register_hotness(self, deployment, rank_load, num_layer, num_expert):
for layer in range(num_layer): for layer in range(num_layer):
if layer not in self.hotness_window: if layer not in self.hotness_window:
self.hotness_window[layer] = deque(maxlen=self.max_stage_window)
hotness = self.compute_expert_hotness(num_expert, deployment[layer], rank_load[layer])
self.hotness_window[layer].append(hotness) self.hotness_window[layer].append(hotness)
def compress_by_avg_pooling_fast_nd(self, arr, m): def compress_by_avg_pooling_fast_nd(self, arr, m):
n, d = arr.shape n, d = arr.shape
idx = (np.arange(n) * m // n) idx = np.arange(n) * m // n
result = np.zeros((m, d)) result = np.zeros((m, d))
counts = np.zeros((m, 1)) counts = np.zeros((m, 1))
np.add.at(result, idx, arr) np.add.at(result, idx, arr)
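The pooling above maps n rows into m buckets via `idx = arange(n) * m // n` and accumulates with `np.add.at`; a self-contained sketch of the same idea follows (the final division by bucket counts is assumed, since it falls outside this hunk):

# Standalone sketch of average pooling n rows down to m rows; the division step is an assumption.
import numpy as np

arr = np.arange(12, dtype=float).reshape(6, 2)  # n=6 rows, d=2 columns
m = 3
idx = np.arange(6) * m // 6                     # bucket per row: [0, 0, 1, 1, 2, 2]
result = np.zeros((m, 2))
counts = np.zeros((m, 1))
np.add.at(result, idx, arr)                     # sum of rows per bucket
np.add.at(counts, idx, 1)                       # number of rows per bucket
print(result / counts)                          # per-bucket mean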
@@ -532,8 +498,7 @@ class FlashLB(EplbPolicy):
expert_workload += 1 expert_workload += 1
num_layer = expert_workload.shape[0] num_layer = expert_workload.shape[0]
num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
self.register_hotness(current_deployment, expert_workload, num_layer, self.register_hotness(current_deployment, expert_workload, num_layer, num_expert)
num_expert)
new_deployment = current_deployment.copy() new_deployment = current_deployment.copy()
@@ -544,21 +509,17 @@ class FlashLB(EplbPolicy):
for i, layer in enumerate(layers_need_update): for i, layer in enumerate(layers_need_update):
hotness = np.array(self.hotness_window[layer]) hotness = np.array(self.hotness_window[layer])
if hotness.shape[0] > self.max_stage_window: if hotness.shape[0] > self.max_stage_window:
hotness = self.compress_by_avg_pooling_fast_nd( hotness = self.compress_by_avg_pooling_fast_nd(hotness, self.max_stage_window)
hotness, self.max_stage_window)
( (
new_deployment[layer], new_deployment[layer],
new_par[i], new_par[i],
current_par[i], current_par[i],
) = self.rebalance_layer(current_deployment[layer], ) = self.rebalance_layer(current_deployment[layer], hotness, layer_id=layer)
hotness,
layer_id=layer)
priority = new_par / current_par priority = new_par / current_par
priority_idx = np.argsort(priority) priority_idx = np.argsort(priority)
priority_idx = priority_idx[priority[priority_idx] < priority_idx = priority_idx[priority[priority_idx] < 1][: self.buffer_expert_layer_num]
1][:self.buffer_expert_layer_num]
if np.all(expert_workload == 1): if np.all(expert_workload == 1):
for _, layer in enumerate(layers_need_update): for _, layer in enumerate(layers_need_update):
@@ -572,16 +533,12 @@ class FlashLB(EplbPolicy):
layers_need_update = priority_idx layers_need_update = priority_idx
deployment = current_deployment deployment = current_deployment
for layer in layers_need_update: for layer in layers_need_update:
deployment[layer] = auto_fix_new_placement( deployment[layer] = auto_fix_new_placement(current_deployment[layer], new_deployment[layer])
current_deployment[layer], new_deployment[layer])
return change, layers_need_update, deployment return change, layers_need_update, deployment
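For intuition, the selection above keeps only layers whose predicted PAR improves (ratio < 1), best first, capped by `buffer_expert_layer_num`; a toy numpy sketch with invented values:

# Toy illustration of the priority filtering above; the numbers and the cap of 2 are made up.
import numpy as np

new_par = np.array([1.10, 0.80, 1.00, 0.60])
current_par = np.array([1.20, 1.00, 0.90, 1.20])
priority = new_par / current_par                 # < 1 means rebalancing helps that layer
order = np.argsort(priority)
selected = order[priority[order] < 1][:2]        # [:2] plays the role of buffer_expert_layer_num
print(selected)                                  # -> [3 1]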
def generate_layered_experts(num_layers=58, layer_shape=(32, 9), expert_min=0, expert_max=255):
""" """
Generate expert deployment matrix meeting the following conditions: Generate expert deployment matrix meeting the following conditions:
- Total of num_layers layers - Total of num_layers layers
@@ -598,32 +555,25 @@ def generate_layered_experts(num_layers=58,
""" """
# 1. Basic parameter calculation # 1. Basic parameter calculation
expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255)
layer_total = layer_shape[0] * layer_shape[1]  # Total elements in a single layer: 32*9=288
extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32
# 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts)
assert layer_total >= expert_num, (
f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, cannot cover all experts"
)
# 3. Generate layers one by one # 3. Generate layers one by one
layers = [] layers = []
for _ in range(num_layers): for _ in range(num_layers):
# 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64)  # shape (256,)
# 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
extra_experts = torch.randint(expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64)  # shape (32,)
# 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
layer_flat = torch.cat([full_experts, extra_experts], dim=0)  # shape (288,)
# Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
shuffle_idx = torch.randperm(layer_flat.shape[0]) shuffle_idx = torch.randperm(layer_flat.shape[0])
layer_shuffled = layer_flat[shuffle_idx] layer_shuffled = layer_flat[shuffle_idx]
@@ -642,7 +592,6 @@ def warm_up():
exam_config.num_die_per_host = 16 exam_config.num_die_per_host = 16
algo = FlashLB(exam_config) algo = FlashLB(exam_config)
# Generate target tensor # Generate target tensor
expert_tensor = generate_layered_experts(num_layers=58, layer_shape=(32, 9))
algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9))) algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))
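A quick sanity check of the property the docstring above promises (every layer covers all 256 experts, with 32 extra random slots); it assumes `generate_layered_experts` returns a torch tensor of shape `(num_layers, 32, 9)`:

# Sketch only; the return type/shape of generate_layered_experts is an assumption.
import torch

expert_tensor = generate_layered_experts(num_layers=2, layer_shape=(32, 9))
for layer in expert_tensor:
    # 288 slots per layer = 256 distinct experts + 32 random repeats
    assert torch.unique(layer).numel() == 256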


@@ -9,7 +9,6 @@ random.seed(42)
class RandomLoadBalance(EplbPolicy): class RandomLoadBalance(EplbPolicy):
def __init__(self, config: DynamicConfig): def __init__(self, config: DynamicConfig):
super().__init__(config) super().__init__(config)


@@ -16,7 +16,6 @@ class DynamicConfig:
class EplbPolicy: class EplbPolicy:
def __init__(self, config: DynamicConfig): def __init__(self, config: DynamicConfig):
self.config = config self.config = config
@@ -63,7 +62,6 @@ class DynamicTable:
class SwiftBalanceEplb(EplbPolicy): class SwiftBalanceEplb(EplbPolicy):
def __init__(self, config: DynamicConfig): def __init__(self, config: DynamicConfig):
super().__init__(config) super().__init__(config)
@@ -89,8 +87,7 @@ class SwiftBalanceEplb(EplbPolicy):
return a % b return a % b
@staticmethod @staticmethod
def add_redundant(current_expert_table, expert_workload, num_original_expert):
layer_num, npu_num, experts_per_npu = expert_workload.shape layer_num, npu_num, experts_per_npu = expert_workload.shape
workload_new = np.zeros((layer_num, num_original_expert)) workload_new = np.zeros((layer_num, num_original_expert))
for layer_idx in range(layer_num): for layer_idx in range(layer_num):
@@ -99,8 +96,7 @@ class SwiftBalanceEplb(EplbPolicy):
workload_layer = expert_workload[layer_idx].copy() workload_layer = expert_workload[layer_idx].copy()
for npu_idx in range(npu_num): for npu_idx in range(npu_num):
for expert_idx in range(experts_per_npu): for expert_idx in range(experts_per_npu):
workload_dict[placement_layer[npu_idx][ workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx]
expert_idx]] += workload_layer[npu_idx][expert_idx]
for expert_idx in range(num_original_expert): for expert_idx in range(num_original_expert):
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
return workload_new return workload_new
@@ -118,9 +114,7 @@ class SwiftBalanceEplb(EplbPolicy):
max_heat_per_layer.append(np.max(npu_heats_now)) max_heat_per_layer.append(np.max(npu_heats_now))
return max_heat_per_layer return max_heat_per_layer
def calculate_initial_imbalance(self, global_deployment, new_layer_workloads):
device_num = global_deployment.shape[1] device_num = global_deployment.shape[1]
layer_imbalance = [] layer_imbalance = []
expert_num = np.zeros_like(new_layer_workloads) expert_num = np.zeros_like(new_layer_workloads)
@@ -136,56 +130,54 @@ class SwiftBalanceEplb(EplbPolicy):
box_workload = 0 box_workload = 0
for expert_id in box: for expert_id in box:
update_workload = self.safe_divide( update_workload = self.safe_divide(
new_layer_workloads[layer_id][expert_id], new_layer_workloads[layer_id][expert_id], expert_num[layer_id][expert_id]
expert_num[layer_id][expert_id]) )
box_workload += update_workload box_workload += update_workload
total_workload += update_workload total_workload += update_workload
if cur_layer_max_workload < box_workload: if cur_layer_max_workload < box_workload:
cur_layer_max_workload = box_workload cur_layer_max_workload = box_workload
cur_layer_imbalance = self.safe_divide( cur_layer_imbalance = self.safe_divide(
cur_layer_max_workload, cur_layer_max_workload, (self.safe_divide(total_workload, device_num))
(self.safe_divide(total_workload, device_num))) )
layer_imbalance.append(cur_layer_imbalance) layer_imbalance.append(cur_layer_imbalance)
return layer_imbalance return layer_imbalance
def compute_redundant_assignments(self, base_experts, num_redundant_experts, num_experts):
redundant_assignments: list[list[int]] = [[] for _ in range(num_experts)]
current_weights = base_experts.copy()
for i in range(num_redundant_experts):
sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1]
sorted_weights = [current_weights[i] for i in sorted_indices]
target_expert = sorted_weights[0]
expert_id, original_weight = target_expert
current_redundancy = len(redundant_assignments[expert_id])
new_avg_weight = self.safe_divide(original_weight * (current_redundancy + 1), (current_redundancy + 2))
redundant_assignments[expert_id].append(num_experts + i)
current_weights[sorted_indices[0]] = (expert_id, new_avg_weight)
sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1]
sorted_weights = [current_weights[i] for i in sorted_indices]
return redundant_assignments, sorted_weights
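The replica-averaging step above keeps each entry's weight equal to its total load divided by its copy count; a toy walk-through with an invented load:

# Toy illustration of new_avg_weight = w * (r + 1) / (r + 2); values are made up.
w0 = 90.0                   # hypothetical total load of the hottest expert
weight, redundancy = w0, 0
for _ in range(3):
    weight = weight * (redundancy + 1) / (redundancy + 2)
    redundancy += 1
    print(redundancy + 1, weight)   # copies, per-copy load: (2, 45.0), (3, 30.0), (4, 22.5)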
def repeat_compute_redundant_assignments(
self,
layer_workloads,
rendun_pos,
num_experts,
num_exist_expert,
device_assignments,
device_counts,
expert_from_device,
com_between_devices,
):
current_weights = np.zeros((num_experts,), dtype="object")
for expert_id, workload_weight in enumerate(layer_workloads): for expert_id, workload_weight in enumerate(layer_workloads):
current_weights[expert_id] = (expert_id, workload_weight) current_weights[expert_id] = (expert_id, workload_weight)
@@ -195,8 +187,7 @@ class SwiftBalanceEplb(EplbPolicy):
devices_with_slots.append(device_id) devices_with_slots.append(device_id)
while devices_with_slots: while devices_with_slots:
sorted_indices = np.argsort([w for _, w in current_weights], sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1]
kind='stable')[::-1]
sorted_weights = [current_weights[i] for i in sorted_indices] sorted_weights = [current_weights[i] for i in sorted_indices]
for index, target_weight in enumerate(sorted_weights): for index, target_weight in enumerate(sorted_weights):
@@ -210,17 +201,15 @@ class SwiftBalanceEplb(EplbPolicy):
pos = rendun_pos[cur_device_id].pop() pos = rendun_pos[cur_device_id].pop()
if len(rendun_pos[cur_device_id]) == 0: if len(rendun_pos[cur_device_id]) == 0:
devices_with_slots = [ devices_with_slots = [
device_id for device_id in devices_with_slots device_id for device_id in devices_with_slots if device_id != cur_device_id
if device_id != cur_device_id
] ]
device_assignments[cur_device_id][pos] = expert_id device_assignments[cur_device_id][pos] = expert_id
device_counts[cur_device_id] += 1 device_counts[cur_device_id] += 1
communication_box_index = expert_from_device[expert_id] communication_box_index = expert_from_device[expert_id]
com_between_devices[cur_device_id][ com_between_devices[cur_device_id][communication_box_index] = expert_id
communication_box_index] = expert_id
new_weight = self.safe_divide( new_weight = self.safe_divide(
(original_weight * num_exist_expert[expert_id]), (original_weight * num_exist_expert[expert_id]), (num_exist_expert[expert_id] + 1)
(num_exist_expert[expert_id] + 1)) )
sorted_weights[index] = (expert_id, new_weight) sorted_weights[index] = (expert_id, new_weight)
num_exist_expert[expert_id] += 1 num_exist_expert[expert_id] += 1
redundancy_successful = True redundancy_successful = True
@@ -228,41 +217,31 @@ class SwiftBalanceEplb(EplbPolicy):
if redundancy_successful: if redundancy_successful:
break break
sorted_indices = np.argsort([id for id, _ in sorted_weights], sorted_indices = np.argsort([id for id, _ in sorted_weights], kind="stable")
kind='stable')
sorted_weights = [sorted_weights[i][1] for i in sorted_indices] sorted_weights = [sorted_weights[i][1] for i in sorted_indices]
return sorted_weights, device_assignments, device_counts, com_between_devices return sorted_weights, device_assignments, device_counts, com_between_devices
@staticmethod @staticmethod
def prepare_expert_list(base_experts, redundant_assignments, num_redundant_experts):
redundant_expert_list = np.empty(num_redundant_experts, dtype=object) redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
index = 0 index = 0
num_experts = len(redundant_assignments) num_experts = len(redundant_assignments)
for expert_id in range(num_experts): for expert_id in range(num_experts):
for _ in redundant_assignments[expert_id]: for _ in redundant_assignments[expert_id]:
redundant_expert_list[index] = (expert_id, next(w for eid, w in base_experts if eid == expert_id))
index += 1
sorted_indices = np.argsort([w for _, w in redundant_expert_list], kind="stable")[::-1]
return [redundant_expert_list[i] for i in sorted_indices] return [redundant_expert_list[i] for i in sorted_indices]
@staticmethod @staticmethod
def non_redundant_expert_information(origin_deployment, updated_weights, rendun_pos):
device_num = len(origin_deployment)
num_experts_per_device = origin_deployment.shape[1]
device_assignments = [[-1 for _ in range(num_experts_per_device)] for _ in range(device_num)]
device_weights = [[0 for _ in range(num_experts_per_device)] for _ in range(device_num)]
device_loads = [0] * device_num device_loads = [0] * device_num
device_counts = [0] * device_num device_counts = [0] * device_num
@@ -272,8 +251,8 @@ class SwiftBalanceEplb(EplbPolicy):
continue continue
device_assignments[device_id][index] = expert_id device_assignments[device_id][index] = expert_id
cur_weight = next( cur_weight = next(
weight for expert_id_of_weight, weight in updated_weights weight for expert_id_of_weight, weight in updated_weights if expert_id_of_weight == expert_id
if expert_id_of_weight == expert_id) )
device_weights[device_id][index] = cur_weight device_weights[device_id][index] = cur_weight
device_loads[device_id] += cur_weight device_loads[device_id] += cur_weight
device_counts[device_id] += 1 device_counts[device_id] += 1
@@ -292,19 +271,24 @@ class SwiftBalanceEplb(EplbPolicy):
if num_all_experts[expert_id] == 0: if num_all_experts[expert_id] == 0:
cur_layer_workload.append(-1) cur_layer_workload.append(-1)
else: else:
cur_layer_workload.append( cur_layer_workload.append(self.safe_divide(weight, num_all_experts[expert_id]))
self.safe_divide(weight, num_all_experts[expert_id]))
return cur_layer_workload, num_all_experts return cur_layer_workload, num_all_experts
def distribute_redun_experts(
self,
layer_workloads,
device_assignments,
device_weights,
device_loads,
device_counts,
redundant_expert_list,
expert_from_device,
num_experts,
rendun_pos,
):
num_devices = len(device_assignments)
com_between_devices: list[dict[int, int]] = [{} for _ in range(num_devices)]
for expert_id, weight in redundant_expert_list: for expert_id, weight in redundant_expert_list:
candidate = -1 candidate = -1
@@ -313,8 +297,7 @@ class SwiftBalanceEplb(EplbPolicy):
continue continue
if expert_id in device_assignments[dev_id]: if expert_id in device_assignments[dev_id]:
continue continue
if candidate == -1 or device_loads[dev_id] < device_loads[ if candidate == -1 or device_loads[dev_id] < device_loads[candidate]:
candidate]:
candidate = dev_id candidate = dev_id
if candidate != -1: if candidate != -1:
pos = rendun_pos[candidate].pop() pos = rendun_pos[candidate].pop()
@@ -324,31 +307,42 @@ class SwiftBalanceEplb(EplbPolicy):
device_counts[candidate] += 1 device_counts[candidate] += 1
communication_box_index = expert_from_device[expert_id] communication_box_index = expert_from_device[expert_id]
com_between_devices[candidate][ com_between_devices[candidate][communication_box_index] = expert_id
communication_box_index] = expert_id
if any(sublist for sublist in rendun_pos): if any(sublist for sublist in rendun_pos):
cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(layer_workloads, device_assignments)
update_workload, device_assignments, device_counts, com_between_devices = (
self.repeat_compute_redundant_assignments(
cur_layer_workload,
rendun_pos,
num_experts,
num_exist_expert,
device_assignments,
device_loads,
expert_from_device,
com_between_devices,
)
)
device_loads = [0] * len(device_counts) device_loads = [0] * len(device_counts)
for device_id, device in enumerate(device_assignments): for device_id, device in enumerate(device_assignments):
for index, expert_id in enumerate(device): for index, expert_id in enumerate(device):
device_weights[device_id][index] = update_workload[ device_weights[device_id][index] = update_workload[expert_id]
expert_id]
device_loads[device_id] += update_workload[expert_id] device_loads[device_id] += update_workload[expert_id]
return device_assignments, device_weights, device_loads, device_counts, com_between_devices return device_assignments, device_weights, device_loads, device_counts, com_between_devices
def redundancy_again(
self,
layer_workloads,
origin_weights,
origin_deployment,
expert_from_device,
num_node,
is_node_redundant,
rendun_pos,
):
num_experts = len(origin_weights) num_experts = len(origin_weights)
if is_node_redundant: if is_node_redundant:
num_experts = num_experts * num_node num_experts = num_experts * num_node
@@ -358,25 +352,33 @@ class SwiftBalanceEplb(EplbPolicy):
num_redundant_experts += len(rank_empty_pos) num_redundant_experts += len(rank_empty_pos)
redundant_assignments, updated_weights = self.compute_redundant_assignments(
origin_weights, num_redundant_experts, num_experts
)
redundant_expert_list = self.prepare_expert_list(updated_weights, redundant_assignments, num_redundant_experts)
device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information(
origin_deployment, updated_weights, rendun_pos
)
device_assignments, device_weights, device_loads, device_counts, com_between_devices = (
self.distribute_redun_experts(
layer_workloads,
device_assignments,
device_weights,
device_loads,
device_counts,
redundant_expert_list,
expert_from_device,
num_experts,
rendun_pos,
)
)
return device_assignments, device_weights, device_loads, device_counts, com_between_devices return device_assignments, device_weights, device_loads, device_counts, com_between_devices
@staticmethod @staticmethod
def generate_allocation_report(device_assignments, device_weights, device_loads, device_counts):
report = [] report = []
max_load = 0.0 max_load = 0.0
@@ -384,27 +386,27 @@ class SwiftBalanceEplb(EplbPolicy):
current_load = device_loads[dev_id] current_load = device_loads[dev_id]
max_load = max(max_load, current_load) max_load = max(max_load, current_load)
report.append(
{
"device_id": dev_id + 1,
"assigned_experts": device_assignments[dev_id],
"expert_weights": device_weights[dev_id],
"total_load": current_load,
"expert_count": device_counts[dev_id],
}
)
return report, max_load return report, max_load
@staticmethod @staticmethod
def exchange_expert(
cur_exchange_index, next_exchange_index, cur_device_id, next_device_id, cur_layer_result, com_between_devices
):
cur_device_deployment = cur_layer_result[cur_device_id]["assigned_experts"]
next_device_deployment = cur_layer_result[next_device_id]["assigned_experts"]
cur_device_weight = cur_layer_result[cur_device_id]["expert_weights"]
next_device_weight = cur_layer_result[next_device_id]["expert_weights"]
cur_expert_id = cur_device_deployment[cur_exchange_index] cur_expert_id = cur_device_deployment[cur_exchange_index]
next_expert_id = next_device_deployment[next_exchange_index] next_expert_id = next_device_deployment[next_exchange_index]
@@ -416,29 +418,25 @@ class SwiftBalanceEplb(EplbPolicy):
cur_device_weight[cur_exchange_index] = next_expert_weight cur_device_weight[cur_exchange_index] = next_expert_weight
next_device_weight[next_exchange_index] = cur_expert_weight next_device_weight[next_exchange_index] = cur_expert_weight
cur_layer_result[cur_device_id]["total_load"] += next_expert_weight - cur_expert_weight
cur_layer_result[next_device_id]["total_load"] += cur_expert_weight - next_expert_weight
com_between_devices[cur_device_id][next_device_id] = next_expert_id com_between_devices[cur_device_id][next_device_id] = next_expert_id
com_between_devices[next_device_id][cur_device_id] = cur_expert_id com_between_devices[next_device_id][cur_device_id] = cur_expert_id
def redundant_expert_deployment(
self, layer_workloads, original_deployment, expert_from_device, node_num, is_node_redundant, rendun_pos
):
device_num, per_device_expert_num = original_deployment.shape device_num, per_device_expert_num = original_deployment.shape
route_expert_num = layer_workloads.shape[0] route_expert_num = layer_workloads.shape[0]
per_node_device_num = self.safe_exact_divide(device_num, node_num) per_node_device_num = self.safe_exact_divide(device_num, node_num)
per_node_route_expert_num = per_node_device_num * (per_device_expert_num - 1)
weights = np.zeros((route_expert_num, ), dtype='object') weights = np.zeros((route_expert_num,), dtype="object")
for expert_id, workload_weight in enumerate(layer_workloads): for expert_id, workload_weight in enumerate(layer_workloads):
weights[expert_id] = (expert_id, workload_weight) weights[expert_id] = (expert_id, workload_weight)
if is_node_redundant: if is_node_redundant:
device_assignments = [] device_assignments = []
device_weights = [] device_weights = []
device_loads = [] device_loads = []
@@ -446,23 +444,30 @@ class SwiftBalanceEplb(EplbPolicy):
com_between_devices = [] com_between_devices = []
for node_id in range(node_num): for node_id in range(node_num):
cur_node_weights = weights[
node_id * per_node_route_expert_num : (node_id + 1) * per_node_route_expert_num
]
cur_original_deployment = original_deployment[
node_id * per_node_device_num : (node_id + 1) * per_node_device_num
]
cur_node_rendun_pos = rendun_pos[node_id * per_node_device_num : (node_id + 1) * per_node_device_num]
(
cur_device_assignments,
cur_device_weights,
cur_device_loads,
cur_device_counts,
cur_com_between_devices,
) = self.redundancy_again(
layer_workloads,
cur_node_weights,
cur_original_deployment,
expert_from_device,
node_num,
is_node_redundant,
cur_node_rendun_pos,
)
device_assignments += cur_device_assignments device_assignments += cur_device_assignments
device_weights += cur_device_weights device_weights += cur_device_weights
device_loads += cur_device_loads device_loads += cur_device_loads
@@ -470,28 +475,41 @@ class SwiftBalanceEplb(EplbPolicy):
com_between_devices += cur_com_between_devices com_between_devices += cur_com_between_devices
else: else:
device_assignments, device_weights, device_loads, device_counts, com_between_devices = (
self.redundancy_again(
layer_workloads,
weights,
original_deployment,
expert_from_device,
node_num,
is_node_redundant,
rendun_pos,
)
)
report, max_load = self.generate_allocation_report(
device_assignments, device_weights, device_loads, device_counts
)
return report, max_load, com_between_devices return report, max_load, com_between_devices
@staticmethod @staticmethod
def two_device_exchange_experts(
cur_device_result,
exchange_device_result,
cur_exchanged_expert_id,
next_exchanged_expert_id,
ave_workload,
increment,
num_redundancy_expert,
):
cur_device_weight = cur_device_result["expert_weights"]
next_device_weight = exchange_device_result["expert_weights"]
cur_device_expert_id = cur_device_result["assigned_experts"]
next_device_expert_id = exchange_device_result["assigned_experts"]
cur_device_total_weight = cur_device_result["total_load"]
next_device_total_weight = exchange_device_result["total_load"]
max_weight = max(cur_device_total_weight, next_device_total_weight) max_weight = max(cur_device_total_weight, next_device_total_weight)
cur_exchange_index = -1 cur_exchange_index = -1
@@ -500,50 +518,47 @@ class SwiftBalanceEplb(EplbPolicy):
for index, weight in enumerate(cur_device_weight): for index, weight in enumerate(cur_device_weight):
for next_index, next_weight in enumerate(next_device_weight): for next_index, next_weight in enumerate(next_device_weight):
change_flag = True change_flag = True
if (
cur_device_expert_id[index] in next_device_expert_id
or next_device_expert_id[next_index] in cur_device_expert_id
):
change_flag = False
if (
(cur_device_expert_id[index] not in cur_exchanged_expert_id)
and (next_device_expert_id[next_index] not in next_exchanged_expert_id)
and change_flag
):
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
exchange_max_weight = max(cur_total_weight_after_exchange, next_total_weight_after_exchange)
if exchange_max_weight < max_weight and (max_weight - exchange_max_weight) >= (
ave_workload * increment
):
max_weight = exchange_max_weight max_weight = exchange_max_weight
cur_exchange_index = index cur_exchange_index = index
next_exchange_index = next_index next_exchange_index = next_index
return cur_exchange_index, next_exchange_index return cur_exchange_index, next_exchange_index
def expert_exchange_between_devices(
self,
ave_workload,
increment,
cur_layer_result,
com_between_devices,
num_redundancy_expert,
node_idx=0,
per_node_device_num=0,
is_node_redundant=False,
):
if is_node_redundant: if is_node_redundant:
cur_devices_result = cur_layer_result[node_idx * per_node_device_num : (node_idx + 1) * per_node_device_num]
else: else:
cur_devices_result = cur_layer_result cur_devices_result = cur_layer_result
devices_total_weight = [] devices_total_weight = []
for device in cur_devices_result: for device in cur_devices_result:
devices_total_weight.append((device["total_load"], device["device_id"] - 1))
exchange_frequency = 100 exchange_frequency = 100
while exchange_frequency > 0: while exchange_frequency > 0:
@@ -553,64 +568,81 @@ class SwiftBalanceEplb(EplbPolicy):
exchange = False exchange = False
for index in range(0, len(devices_total_weight) - 1): for index in range(0, len(devices_total_weight) - 1):
min_weight_device_id = devices_total_weight[index][1] min_weight_device_id = devices_total_weight[index][1]
if min_weight_device_id not in com_between_devices[max_weight_device_id]:
cur_exchanged_expert_id = list(com_between_devices[max_weight_device_id].values())
next_exchanged_expert_id = list(com_between_devices[min_weight_device_id].values())
cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
cur_layer_result[max_weight_device_id], cur_layer_result[max_weight_device_id],
cur_layer_result[min_weight_device_id], cur_layer_result[min_weight_device_id],
cur_exchanged_expert_id,
next_exchanged_expert_id,
ave_workload,
increment,
num_redundancy_expert,
)
if cur_exchange_index != -1: if cur_exchange_index != -1:
self.exchange_expert(
cur_exchange_index,
next_exchange_index,
max_weight_device_id,
min_weight_device_id,
cur_layer_result,
com_between_devices,
)
devices_total_weight[-1] = (
cur_layer_result[max_weight_device_id]["total_load"],
max_weight_device_id,
)
devices_total_weight[index] = (
cur_layer_result[min_weight_device_id]["total_load"],
min_weight_device_id,
)
exchange = True exchange = True
break break
if not exchange: if not exchange:
break break
def exchange_experts(
self,
layer_result,
layer_com_between_devices,
num_nodes,
device_num,
is_node_redundant,
ave_workload,
increment,
num_redundancy_expert,
org_deployment,
):
global_deployment = [] global_deployment = []
if is_node_redundant: if is_node_redundant:
per_node_device_num = self.safe_exact_divide(device_num, num_nodes) per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
for node_idx in range(num_nodes): for node_idx in range(num_nodes):
self.expert_exchange_between_devices( self.expert_exchange_between_devices(
ave_workload,
increment,
layer_result,
layer_com_between_devices,
num_redundancy_expert,
node_idx,
per_node_device_num,
is_node_redundant,
)
else: else:
self.expert_exchange_between_devices(
ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert
)
max_workload = 0 max_workload = 0
for box in layer_result: for box in layer_result:
global_deployment.append(box['assigned_experts']) global_deployment.append(box["assigned_experts"])
if max_workload < box['total_load']: if max_workload < box["total_load"]:
max_workload = box['total_load'] max_workload = box["total_load"]
global_deployment = np.array(global_deployment) global_deployment = np.array(global_deployment)
@@ -626,16 +658,11 @@ class SwiftBalanceEplb(EplbPolicy):
return count return count
@staticmethod @staticmethod
def constraint_expert_local_exchange(current_expert_table, global_deployment):
for layer_id in range(len(global_deployment)): for layer_id in range(len(global_deployment)):
for card_id in range(len(global_deployment[layer_id])): for card_id in range(len(global_deployment[layer_id])):
current_list = [int(x) for x in current_expert_table[layer_id][card_id]]
new_list = [int(x) for x in global_deployment[layer_id][card_id]]
num = len(new_list) num = len(new_list)
new_index = [-1] * num new_index = [-1] * num
@@ -645,8 +672,7 @@ class SwiftBalanceEplb(EplbPolicy):
for i in range(num): for i in range(num):
flag = True flag = True
for j in range(num): for j in range(num):
if new_list[i] == current_list[j] and new_index[j] == -1:
new_index[j] = 0 new_index[j] = 0
new_result[j] = current_list[j] new_result[j] = current_list[j]
flag = False flag = False
@@ -664,25 +690,17 @@ class SwiftBalanceEplb(EplbPolicy):
return global_deployment return global_deployment
def rebalance_experts(self, current_expert_table, expert_workload, is_node_redundant=False, increment=0.01):
info = DynamicTable() info = DynamicTable()
info.workload_table = expert_workload.numpy() info.workload_table = expert_workload.numpy()
info.placement_table = current_expert_table.numpy() info.placement_table = current_expert_table.numpy()
assert info.workload_table is not None assert info.workload_table is not None
layer_num, num_npus, experts_per_npu = info.workload_table.shape layer_num, num_npus, experts_per_npu = info.workload_table.shape
expert_ids, counts = np.unique(info.placement_table[0], return_counts=True)
num_redundancy_expert = self.get_redundant_num(num_npus, counts) num_redundancy_expert = self.get_redundant_num(num_npus, counts)
num_original_expert = len(expert_ids) num_original_expert = len(expert_ids)
layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert)
max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num)
npu_heat_all_origin = sum(max_heat_per_layer_before) npu_heat_all_origin = sum(max_heat_per_layer_before)
num_node = self.safe_exact_divide(num_npus, 8) num_node = self.safe_exact_divide(num_npus, 8)
@@ -700,14 +718,13 @@ class SwiftBalanceEplb(EplbPolicy):
if num_npus < num_redundancy_expert: if num_npus < num_redundancy_expert:
raise ValueError(
"The number of NPUs "
f"({num_npus}) must be greater than or equal to the number of redundant experts "
f"({num_redundancy_expert})"
)
global_deployment: list[list[list[int]]] = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads)
max_heat_per_layer_after = np.zeros([layer_num]) max_heat_per_layer_after = np.zeros([layer_num])
sum_num = 0 sum_num = 0
for layer in range(layer_num): for layer in range(layer_num):
@@ -715,8 +732,7 @@ class SwiftBalanceEplb(EplbPolicy):
global_deployment[layer] = info.placement_table[layer] global_deployment[layer] = info.placement_table[layer]
continue continue
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), num_npus)
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)] rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
existing_experts = set() existing_experts = set()
@@ -729,30 +745,37 @@ class SwiftBalanceEplb(EplbPolicy):
rendun_pos[device_id].append(index) rendun_pos[device_id].append(index)
result, max_workload, com_between_devices = self.redundant_expert_deployment(
layer_workloads[layer],
info.placement_table[layer],
expert_from_device[layer],
num_node,
is_node_redundant,
rendun_pos,
)
global_deployment[layer], new_max_workload = self.exchange_experts(
result,
com_between_devices,
num_node,
num_npus,
is_node_redundant,
ave_workload,
increment,
num_redundancy_expert,
info.placement_table[layer],
)
for device_id in range(num_npus): for device_id in range(num_npus):
com_between_devices[device_id] = {key: value for key, value in com_between_devices[device_id].items()}
sum_num += self.count_elements(com_between_devices[device_id]) sum_num += self.count_elements(com_between_devices[device_id])
max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_load"])["total_load"]
layer_changed_ratio = [] layer_changed_ratio = []
for layer_idx in range(layer_num): for layer_idx in range(layer_num):
layer_changed_ratio.append(
self.safe_divide(max_heat_per_layer_after[layer_idx], max_heat_per_layer_before[layer_idx])
)
per_layer_priority = np.argsort(layer_changed_ratio) per_layer_priority = np.argsort(layer_changed_ratio)
npu_heat_all_after = sum(max_heat_per_layer_after) npu_heat_all_after = sum(max_heat_per_layer_after)
@@ -761,8 +784,6 @@ class SwiftBalanceEplb(EplbPolicy):
if npu_heat_all_after < 0.95 * npu_heat_all_origin: if npu_heat_all_after < 0.95 * npu_heat_all_origin:
change = 1 change = 1
new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
return change, per_layer_priority, np.array(new_global_deployment).tolist()


@@ -26,9 +26,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
class EplbUpdator: class EplbUpdator:
def __init__(self, eplb_config, loader, eplb_process: EplbProcess, process):
self.eplb_config = eplb_config self.eplb_config = eplb_config
self.init_eplb(self.eplb_config.expert_map_path, process) self.init_eplb(self.eplb_config.expert_map_path, process)
self.eplb_loader = loader self.eplb_loader = loader
@@ -43,9 +41,7 @@ class EplbUpdator:
self.world_size = dist.get_world_size() self.world_size = dist.get_world_size()
self.device = local_load.device self.device = local_load.device
shape = (self.world_size, *local_load.shape) shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape, dtype=local_load.dtype, device=self.device)
def init_eplb(self, expert_map_path, process): def init_eplb(self, expert_map_path, process):
self.rank_id = dist.get_rank() self.rank_id = dist.get_rank()
@@ -72,52 +68,49 @@ class EplbUpdator:
self.process = process self.process = process
logger.info(f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")
def update_iteration(self): def update_iteration(self):
self.cur_iterations += 1 self.cur_iterations += 1
if self.cur_iterations == (
self.expert_heat_collection_interval + self.algorithm_execution_interval + self.num_moe_layers
):
logger.info("Finish expert parallel load balancing.") logger.info("Finish expert parallel load balancing.")
if self.expert_map_record_path is not None: if self.expert_map_record_path is not None:
self.adaptor._export_tensor_to_file(self.shared_dict["expert_maps"], self.expert_map_record_path)
self.adaptor.model.clear_all_moe_loads() self.adaptor.model.clear_all_moe_loads()
self.cur_iterations = 0 self.cur_iterations = 0
def get_update_info_flag(self): def get_update_info_flag(self):
return self.cur_iterations == (self.expert_heat_collection_interval + return self.cur_iterations == (self.expert_heat_collection_interval + self.algorithm_execution_interval - 1)
self.algorithm_execution_interval - 1)
def wakeup_eplb_worker_flag(self): def wakeup_eplb_worker_flag(self):
return self.cur_iterations == (self.expert_heat_collection_interval - return self.cur_iterations == (self.expert_heat_collection_interval - 1)
1)
def update_expert_weight_flag(self): def update_expert_weight_flag(self):
weight_update_counter = self.cur_iterations - ( weight_update_counter = self.cur_iterations - (
self.expert_heat_collection_interval + self.expert_heat_collection_interval + self.algorithm_execution_interval
self.algorithm_execution_interval) )
return (weight_update_counter >= 0 return weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers
and weight_update_counter < self.num_moe_layers)
def wakeup_eplb_worker(self): def wakeup_eplb_worker(self):
self.eplb_process.planner_q.put(1) self.eplb_process.planner_q.put(1)
def forward_before(self): def forward_before(self):
if self.update_expert_weight_flag(): if self.update_expert_weight_flag():
(expert_send_info, expert_recv_info, updated_expert_map, (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
log2phy_map, layer_id) = self.update_info_all.pop(0) 0
)
log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map)) log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
self.eplb_loader.set_log2phy_map(log2phy_map_this_rank) self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
updated_expert_map_this_rank = torch.from_numpy( updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
numpy.array(updated_expert_map))
self.eplb_loader.generate_expert_d2d_transfer_task( self.eplb_loader.generate_expert_d2d_transfer_task(
expert_send_info, expert_recv_info, expert_send_info,
expert_recv_info,
updated_expert_map_this_rank, updated_expert_map_this_rank,
layer_id + self.adaptor.num_dense_layers) layer_id + self.adaptor.num_dense_layers,
)
# set asynchronous stream for d2d expert weight update # set asynchronous stream for d2d expert weight update
self.reqs = [] self.reqs = []
@@ -133,8 +126,7 @@ class EplbUpdator:
self.compute_and_set_moe_load() self.compute_and_set_moe_load()
self.wakeup_eplb_worker() self.wakeup_eplb_worker()
if self.update_expert_weight_flag( if self.update_expert_weight_flag() and self.expert_map_record_path is None:
) and self.expert_map_record_path is None:
self.eplb_loader.update_expert_map_and_weight(self.reqs) self.eplb_loader.update_expert_map_and_weight(self.reqs)
self.update_iteration() self.update_iteration()
@@ -145,9 +137,7 @@ class EplbUpdator:
moe_load = self._gather_buffer.permute(1, 0, 2) moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu() self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug( logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
if dist.get_rank() == 0: if dist.get_rank() == 0:
self.compute_moe_imbalance(moe_load) self.compute_moe_imbalance(moe_load)
@@ -156,7 +146,6 @@ class EplbUpdator:
return moe_load return moe_load
def compute_moe_imbalance(self, moe_load: torch.Tensor): def compute_moe_imbalance(self, moe_load: torch.Tensor):
self.moe_imbalance_dict.clear() self.moe_imbalance_dict.clear()
layer_card_load = moe_load.sum(dim=-1).cpu().float() layer_card_load = moe_load.sum(dim=-1).cpu().float()
@@ -169,13 +158,11 @@ class EplbUpdator:
moe_load_imbalance = max_load / (mean_load + 1e-6) moe_load_imbalance = max_load / (mean_load + 1e-6)
logger.debug(f"[ModelRunner][MOE_load_stats][Layer {layer_idx}] " logger.debug(f"[ModelRunner][MOE_load_stats][Layer {layer_idx}] PAR={moe_load_imbalance:.4f}")
f"PAR={moe_load_imbalance:.4f}")
self.moe_imbalance_dict[layer_idx] = moe_load_imbalance self.moe_imbalance_dict[layer_idx] = moe_load_imbalance
def summarize_moe_imbalance(self): def summarize_moe_imbalance(self):
values = list(self.moe_imbalance_dict.values()) values = list(self.moe_imbalance_dict.values())
if not values: if not values:
logger.info("[MOE_load_stats] No data available.") logger.info("[MOE_load_stats] No data available.")
@@ -191,11 +178,10 @@ class EplbUpdator:
) )
def warm_up_eplb(self): def warm_up_eplb(self):
self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map() self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map()
self.compute_and_set_moe_load() self.compute_and_set_moe_load()
src_tensor = torch.empty((1, ), device=self.device) src_tensor = torch.empty((1,), device=self.device)
self_rank = dist.get_rank() self_rank = dist.get_rank()
comm_op_list = [] comm_op_list = []
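The EplbUpdator hunk above drives everything off a single iteration counter: collect expert heat for expert_heat_collection_interval steps, wake the EPLB worker on the last heat-collection step, fetch the new placement on the last algorithm-execution step, then spend one step per MoE layer swapping weights. Below is a minimal standalone sketch of those phase checks with invented interval values (the real ones come from eplb_config); it is an illustration, not the vLLM-Ascend class itself.

```python
# Sketch only: the interval values below are invented for illustration.
HEAT_INTERVAL = 4      # expert_heat_collection_interval (assumed)
ALGO_INTERVAL = 2      # algorithm_execution_interval (assumed)
NUM_MOE_LAYERS = 3     # one weight-update step per MoE layer


def phase(cur_iterations: int) -> str:
    """Map an iteration index onto the EPLB phases checked above."""
    if cur_iterations == HEAT_INTERVAL - 1:
        return "wakeup_eplb_worker"            # wakeup_eplb_worker_flag()
    if cur_iterations == HEAT_INTERVAL + ALGO_INTERVAL - 1:
        return "fetch_update_info"             # get_update_info_flag()
    step = cur_iterations - (HEAT_INTERVAL + ALGO_INTERVAL)
    if 0 <= step < NUM_MOE_LAYERS:
        return f"update_weights_layer_{step}"  # update_expert_weight_flag()
    return "collect_heat"


for it in range(HEAT_INTERVAL + ALGO_INTERVAL + NUM_MOE_LAYERS):
    print(it, phase(it))
# The counter resets after HEAT_INTERVAL + ALGO_INTERVAL + NUM_MOE_LAYERS steps.
```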

View File

@@ -30,33 +30,30 @@ def get_log2phy_map(self, layer_id):
def get_all_expert_map(self, num_moe_layers): def get_all_expert_map(self, num_moe_layers):
all_loads = [] all_loads = []
num_dense_layers = self.num_dense_layers if hasattr( num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0
self, "num_dense_layers") else 0
for layer_id in range(num_moe_layers): for layer_id in range(num_moe_layers):
load_tensor = self.get_expert_map( load_tensor = self.get_expert_map(layer_id + num_dense_layers) # (num_experts_per_layer,)
layer_id + num_dense_layers) # (num_experts_per_layer,)
all_loads.append(load_tensor) all_loads.append(load_tensor)
return torch.stack(all_loads, dim=0) return torch.stack(all_loads, dim=0)
def get_all_moe_loads(self): def get_all_moe_loads(self):
num_dense_layers = self.num_dense_layers if hasattr( num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0
self, "num_dense_layers") else 0
all_moe_loads = torch.stack( all_moe_loads = torch.stack(
[self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \ [
for layer_id in range(self.num_moe_layers)], self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load
dim=0 for layer_id in range(self.num_moe_layers)
],
dim=0,
) )
return all_moe_loads return all_moe_loads
def clear_all_moe_loads(self): def clear_all_moe_loads(self):
num_dense_layers = self.num_dense_layers if hasattr( num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0
self, "num_dense_layers") else 0
for layer_id in range(self.num_moe_layers): for layer_id in range(self.num_moe_layers):
self.model.layers[layer_id + self.model.layers[layer_id + num_dense_layers].mlp.experts.clear_moe_load()
num_dense_layers].mlp.experts.clear_moe_load()
def model_register(model, model_config): def model_register(model, model_config):
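A quick shape note on the hunk above: get_all_expert_map and get_all_moe_loads each gather one tensor per MoE layer and stack them along a new leading dimension, so downstream code sees a (num_moe_layers, num_experts_per_layer) tensor. Illustrative snippet with made-up sizes:

```python
import torch

# Hypothetical sizes, purely for illustration.
num_moe_layers, num_experts_per_layer = 4, 8

per_layer = [torch.randint(0, 100, (num_experts_per_layer,)) for _ in range(num_moe_layers)]
stacked = torch.stack(per_layer, dim=0)
print(stacked.shape)  # torch.Size([4, 8])
```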

View File

@@ -18,8 +18,7 @@ import torch
import torch_npu import torch_npu
from vllm.logger import logger from vllm.logger import logger
from .netloader_pg import (destroy_stateless_process_group, from .netloader_pg import destroy_stateless_process_group, stateless_init_process_group
stateless_init_process_group)
class P2PLoad: class P2PLoad:
@@ -56,9 +55,7 @@ class P2PLoad:
- The model if loading is successful, otherwise None. - The model if loading is successful, otherwise None.
""" """
model_device = next(model.parameters()).device model_device = next(model.parameters()).device
logger.info( logger.info(f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
receiver_pg = None receiver_pg = None
loaded_model = None loaded_model = None
try: try:
@@ -67,15 +64,13 @@ class P2PLoad:
port=self.source_port, port=self.source_port,
rank=0, rank=0,
world_size=2, world_size=2,
group_name='netloader', group_name="netloader",
) )
logger.info( logger.info(
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
) )
logger.info( logger.info(f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
logger.info(f"Model device: {model_device}") logger.info(f"Model device: {model_device}")
trans_stream = torch_npu.npu.Stream() trans_stream = torch_npu.npu.Stream()
@@ -84,14 +79,11 @@ class P2PLoad:
if len(param.shape) == 0: if len(param.shape) == 0:
continue continue
receiver_pg.recv([param], 1, 0).wait() receiver_pg.recv([param], 1, 0).wait()
torch.distributed.barrier(group=receiver_pg, torch.distributed.barrier(group=receiver_pg, device_ids=[model_device.index])
device_ids=[model_device.index])
torch_npu.npu.synchronize(trans_stream) torch_npu.npu.synchronize(trans_stream)
logger.info( logger.info(f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
loaded_model = model loaded_model = model
except Exception as e: except Exception as e:
logger.error("Failed to recv model: {}".format(e)) logger.error("Failed to recv model: {}".format(e))
@@ -129,9 +121,7 @@ class P2PSend:
""" """
model_device = next(model.parameters()).device model_device = next(model.parameters()).device
torch.npu.set_device(model_device) torch.npu.set_device(model_device)
logger.info( logger.info(f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
sender_pg = None sender_pg = None
try: try:
sender_pg = stateless_init_process_group( sender_pg = stateless_init_process_group(
@@ -139,14 +129,10 @@ class P2PSend:
port=self.listen_port, port=self.listen_port,
rank=1, rank=1,
world_size=2, world_size=2,
group_name='netloader', group_name="netloader",
)
logger.info(
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
logger.info(
f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
) )
logger.info(f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
logger.info(f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
logger.info(f"Model device: {model_device}") logger.info(f"Model device: {model_device}")
trans_stream = torch_npu.npu.Stream() trans_stream = torch_npu.npu.Stream()
@@ -155,16 +141,12 @@ class P2PSend:
if "aclnn_input_scale" in name: if "aclnn_input_scale" in name:
continue continue
if name in int8_params: if name in int8_params:
sender_pg.send([int8_params[name].to(model_device)], 0, sender_pg.send([int8_params[name].to(model_device)], 0, 0).wait()
0).wait()
else: else:
sender_pg.send([param.contiguous()], 0, 0).wait() sender_pg.send([param.contiguous()], 0, 0).wait()
torch.distributed.barrier(group=sender_pg, torch.distributed.barrier(group=sender_pg, device_ids=[model_device.index])
device_ids=[model_device.index])
torch_npu.npu.synchronize(trans_stream) torch_npu.npu.synchronize(trans_stream)
logger.info( logger.info(f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
finally: finally:
if sender_pg: if sender_pg:
destroy_stateless_process_group(sender_pg) destroy_stateless_process_group(sender_pg)
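P2PLoad and P2PSend above stream model weights tensor by tensor over a dedicated two-rank process group, then synchronize with a barrier. The sketch below reproduces only that pattern with stock torch.distributed on the CPU gloo backend so it runs without Ascend hardware; the port, model, and backend are placeholders, and the real code uses a stateless HCCL group, NPU streams, and an int8 cache.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn


def make_model() -> nn.Module:
    torch.manual_seed(0)  # same architecture (and init) on both ranks
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29512"  # arbitrary port for the demo
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = make_model()
    if rank == 1:
        for p in model.parameters():  # "receiver": start from empty weights
            p.data.zero_()

    for param in model.parameters():  # stream one tensor at a time, in order
        if rank == 0:
            dist.send(param.data.contiguous(), dst=1)
        else:
            dist.recv(param.data, src=0)

    dist.barrier()
    if rank == 1:
        print("received first weight row:", model[0].weight.data[0])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```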

View File

@@ -17,16 +17,13 @@
import gc import gc
import ipaddress import ipaddress
from datetime import timedelta from datetime import timedelta
from typing import Any, Optional from typing import Any
import torch import torch
import torch_npu import torch_npu
from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT, from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT, _register_process_group, _unregister_process_group
_register_process_group,
_unregister_process_group)
from torch.distributed import ProcessGroup, is_hccl_available from torch.distributed import ProcessGroup, is_hccl_available
from torch.distributed.distributed_c10d import (Backend, BackendConfig, from torch.distributed.distributed_c10d import Backend, BackendConfig, PrefixStore, _world
PrefixStore, _world)
from torch.distributed.rendezvous import rendezvous from torch.distributed.rendezvous import rendezvous
from torch_npu._C._distributed_c10d import ProcessGroupHCCL from torch_npu._C._distributed_c10d import ProcessGroupHCCL
from vllm.logger import logger from vllm.logger import logger
@@ -39,7 +36,7 @@ def stateless_init_process_group(
rank: int, rank: int,
timeout: timedelta = _DEFAULT_PG_TIMEOUT, timeout: timedelta = _DEFAULT_PG_TIMEOUT,
group_name: str = "", group_name: str = "",
pg_options: Optional[Any] = None, pg_options: Any | None = None,
) -> ProcessGroup: ) -> ProcessGroup:
""" """
Initializes a stateless process group. Initializes a stateless process group.
@@ -57,7 +54,8 @@ def stateless_init_process_group(
ProcessGroup: The initialized process group. ProcessGroup: The initialized process group.
Raises: Raises:
RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable. RuntimeError: If world_size is not positive, or if rank is not within
[0, world_size - 1], or if HCCL is unavailable.
TypeError: If timeout is not a timedelta type. TypeError: If timeout is not a timedelta type.
ValueError: If group_name already exists. ValueError: If group_name already exists.
""" """
@@ -67,21 +65,18 @@ def stateless_init_process_group(
raise RuntimeError("world_size must be positive") raise RuntimeError("world_size must be positive")
# Check if rank is within [0, world_size - 1] # Check if rank is within [0, world_size - 1]
if not (rank >= 0 and rank <= world_size - 1): if not (rank >= 0 and rank <= world_size - 1):
raise RuntimeError( raise RuntimeError("rank should be a number between 0 and ``world_size``-1")
"rank should be a number between 0 and ``world_size``-1")
# Check if HCCL is available # Check if HCCL is available
if not is_hccl_available(): if not is_hccl_available():
raise RuntimeError("HCCL is not available") raise RuntimeError("HCCL is not available")
# Check if timeout is a timedelta type # Check if timeout is a timedelta type
if not isinstance(timeout, timedelta): if not isinstance(timeout, timedelta):
raise TypeError( raise TypeError(f"Expected timeout argument to be of type datetime.timedelta, got {timeout}")
f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
)
# Check if group_name already exists # Check if group_name already exists
if group_name in _world.pg_names.values(): if group_name in _world.pg_names.values():
raise ValueError( raise ValueError(
f"The specified group name {group_name} has already been " f"The specified group name {group_name} has already been created, please use a different group name"
"created, please use a different group name") )
# Function to check if an IPv6 address is valid # Function to check if an IPv6 address is valid
def is_valid_ipv6_address(address: str) -> bool: def is_valid_ipv6_address(address: str) -> bool:
@@ -101,10 +96,9 @@ def stateless_init_process_group(
# Get initialization method # Get initialization method
init_method = get_tcp_uri(host, port) init_method = get_tcp_uri(host, port)
# Create Backend object # Create Backend object
backend = Backend('hccl') backend = Backend("hccl")
# Use rendezvous function to get store, rank, and world_size # Use rendezvous function to get store, rank, and world_size
store, rank, world_size = next( store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
rendezvous(init_method, rank, world_size, timeout=timeout))
# Set timeout for store # Set timeout for store
store.set_timeout(timeout) store.set_timeout(timeout)
@@ -125,9 +119,7 @@ def stateless_init_process_group(
pg._set_default_backend(Backend.backend_type_map[backend]) pg._set_default_backend(Backend.backend_type_map[backend])
# Check if pg_options is None or not of type ProcessGroupHCCL.Options # Check if pg_options is None or not of type ProcessGroupHCCL.Options
if pg_options is None or not isinstance( if pg_options is None or not isinstance(pg_options, torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
pg_options,
torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options() pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
# Set attributes for pg_options # Set attributes for pg_options
pg_options.is_high_priority_stream = False pg_options.is_high_priority_stream = False
@@ -135,8 +127,7 @@ def stateless_init_process_group(
pg_options.global_ranks_in_group = [] pg_options.global_ranks_in_group = []
pg_options.group_id = f"{init_method}/{group_name}/" pg_options.group_id = f"{init_method}/{group_name}/"
# Create ProcessGroupHCCL object # Create ProcessGroupHCCL object
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, pg_options)
pg_options)
# Set sequence number for backend_class # Set sequence number for backend_class
backend_class._set_sequence_number_for_group() backend_class._set_sequence_number_for_group()
# Set backend_type # Set backend_type
@@ -176,9 +167,10 @@ def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False):
_world.pg_group_ranks.pop(pg, None) _world.pg_group_ranks.pop(pg, None)
_world.pg_backend_config.pop(pg, None) _world.pg_backend_config.pop(pg, None)
# Check if pg is in keys of _world.pg_coalesce_state # Check if pg is in keys of _world.pg_coalesce_state
if pg in _world.pg_coalesce_state.keys(): if pg in _world.pg_coalesce_state:
logger.warning("Some coalesced collectives haven't been launched when " logger.warning(
"ProcessGroup is destroyed. They will be cleaned.") "Some coalesced collectives haven't been launched when ProcessGroup is destroyed. They will be cleaned."
)
del _world.pg_coalesce_state[pg] del _world.pg_coalesce_state[pg]
# Unregister the process group # Unregister the process group
_unregister_process_group(pg.group_name) _unregister_process_group(pg.group_name)

View File

@@ -18,7 +18,7 @@ import json
import re import re
import socket import socket
import threading import threading
from typing import List, Optional, Tuple from contextlib import suppress
import torch import torch
from vllm.logger import logger from vllm.logger import logger
@@ -32,8 +32,7 @@ class ElasticClient:
Class for handling the client-side logic of Netloader of models. Class for handling the client-side logic of Netloader of models.
""" """
def __init__(self, sources: list[str], device_id: int, model_path: str, def __init__(self, sources: list[str], device_id: int, model_path: str, tp: int, pp: int):
tp: int, pp: int):
""" """
Initializes the ElasticClient instance. Initializes the ElasticClient instance.
@@ -50,14 +49,14 @@ class ElasticClient:
self.tp = tp self.tp = tp
self.pp = pp self.pp = pp
self.s: Optional[socket.socket] = None self.s: socket.socket | None = None
self.ack: Optional[Tuple[str, int]] = None self.ack: tuple[str, int] | None = None
self.server_addr: Optional[str] = None self.server_addr: str | None = None
self.server_port: Optional[int] = None self.server_port: int | None = None
for source in self.sources: for source in self.sources:
try: try:
ip, port_str = source.split(':') ip, port_str = source.split(":")
port = int(port_str) port = int(port_str)
except Exception as e: except Exception as e:
logger.info(f"IP format error: {source}, detail: {e}") logger.info(f"IP format error: {source}, detail: {e}")
@@ -68,13 +67,9 @@ class ElasticClient:
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
logger.info( logger.info(f"Start connection to server: {self.server_addr}:{self.server_port}")
f"Start connection to server: {self.server_addr}:{self.server_port}"
)
sock.connect((self.server_addr, self.server_port)) sock.connect((self.server_addr, self.server_port))
logger.info( logger.info(f"Finish connection to server: {self.server_addr}:{self.server_port}")
f"Finish connection to server: {self.server_addr}:{self.server_port}"
)
sock.settimeout(60) sock.settimeout(60)
self.s = sock self.s = sock
@@ -83,10 +78,8 @@ class ElasticClient:
except Exception as e: except Exception as e:
logger.error(f"Connect to {source} fails, detail: {e}") logger.error(f"Connect to {source} fails, detail: {e}")
if sock is not None: if sock is not None:
try: with suppress(Exception):
sock.close() sock.close()
except Exception:
pass
self.s = None self.s = None
self.ack = None self.ack = None
self.server_addr = None self.server_addr = None
@@ -120,10 +113,8 @@ class ElasticClient:
""" """
Destructor method to ensure socket is closed. Destructor method to ensure socket is closed.
""" """
try: with suppress(Exception):
self.close() self.close()
except Exception:
pass
def send_str(self, data_str: str) -> None: def send_str(self, data_str: str) -> None:
""" """
@@ -151,8 +142,7 @@ class ElasticClient:
data_str = self.s.recv(buffer_size).decode("utf-8") data_str = self.s.recv(buffer_size).decode("utf-8")
return data_str return data_str
def register(self, device_id: int, model_path: str, tp: int, def register(self, device_id: int, model_path: str, tp: int, pp: int) -> tuple[str, int]:
pp: int) -> Tuple[str, int]:
""" """
Registers the client with the server. Registers the client with the server.
@@ -168,20 +158,13 @@ class ElasticClient:
free_port = find_free_port() free_port = find_free_port()
data = { data = {
"label": "JOIN", "label": "JOIN",
"content": { "content": {"device_id": device_id, "model_path": model_path, "tp": tp, "pp": pp, "port": free_port},
'device_id': device_id,
'model_path': model_path,
'tp': tp,
'pp': pp,
'port': free_port
}
} }
try: try:
self.send_str(json.dumps(data)) self.send_str(json.dumps(data))
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"Send data {data} to server fails, detail: {e}")
f"Send data {data} to server fails, detail: {e}")
try: try:
ack_str = self.recv_str() ack_str = self.recv_str()
@@ -191,23 +174,22 @@ class ElasticClient:
try: try:
ack = json.loads(ack_str) ack = json.loads(ack_str)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}")
f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}"
)
logger.info(f"Receive ack: {ack}") logger.info(f"Receive ack: {ack}")
if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack if (
and ack["content"] is not None and "name" in ack["content"]): "label" in ack
and ack["label"] == "JOIN_ACK"
and "content" in ack
and ack["content"] is not None
and "name" in ack["content"]
):
return (ack["content"]["name"], free_port) return (ack["content"]["name"], free_port)
elif ("label" in ack and ack["label"] == 'JOIN_NACK' elif "label" in ack and ack["label"] == "JOIN_NACK" and "content" in ack:
and "content" in ack): raise RuntimeError(f"Receive nack from server, reason: {ack['content']}")
raise RuntimeError(
f"Receive nack from server, reason: {ack['content']}")
else: else:
raise RuntimeError( raise RuntimeError(f"Receive ack {ack} from server does not contain required fields")
f"Receive ack {ack} from server does not contain required fields"
)
class ElasticServer: class ElasticServer:
@@ -215,9 +197,18 @@ class ElasticServer:
Class for handling the server-side logic of Netloader of models. Class for handling the server-side logic of Netloader of models.
""" """
def __init__(self, addr: str, port: int, model, device_id: int, def __init__(
model_path: str, tp: int, pp: int, int8_cache: str, self,
int8_cache_name: Optional[List[str]]): addr: str,
port: int,
model,
device_id: int,
model_path: str,
tp: int,
pp: int,
int8_cache: str,
int8_cache_name: list[str] | None,
):
""" """
Initializes the ElasticServer instance. Initializes the ElasticServer instance.
@@ -246,30 +237,25 @@ class ElasticServer:
self.pp = pp self.pp = pp
self.original_int8 = {} self.original_int8 = {}
int8_pattern = "|".join( int8_pattern = "|".join(map(re.escape, int8_cache_name)) if int8_cache_name is not None else "(?:)"
map(re.escape,
int8_cache_name)) if int8_cache_name is not None else "(?:)"
for name, param in self.model.named_parameters(): for name, param in self.model.named_parameters():
if param.dtype == torch.int8: if param.dtype == torch.int8:
if int8_cache == 'hbm': if int8_cache == "hbm":
if int8_cache_name is None or ( if int8_cache_name is None or (
int8_cache_name is not None int8_cache_name is not None and re.search(int8_pattern, name) is not None
and re.search(int8_pattern, name) is not None): ):
try: try:
self.original_int8[name] = param.data.clone( self.original_int8[name] = param.data.clone().detach()
).detach()
except RuntimeError as e: except RuntimeError as e:
logger.error( logger.error(f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}")
f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}"
)
self.original_int8[name] = param.data.cpu() self.original_int8[name] = param.data.cpu()
elif int8_cache == 'dram': elif int8_cache == "dram":
if int8_cache_name is None or ( if int8_cache_name is None or (
int8_cache_name is not None int8_cache_name is not None and re.search(int8_pattern, name) is not None
and re.search(int8_pattern, name) is not None): ):
self.original_int8[name] = param.data.cpu() self.original_int8[name] = param.data.cpu()
elif int8_cache == 'no': elif int8_cache == "no":
pass pass
else: else:
logger.warning( logger.warning(
@@ -277,14 +263,18 @@ class ElasticServer:
) )
logger.info( logger.info(
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}" f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, "
f"model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, "
f"int8 params {list(self.original_int8)} are saved to {int8_cache}"
) )
def __del__(self): def __del__(self):
""" """
Destructor method to ensure socket is closed. Destructor method to ensure socket is closed.
""" """
self.s.close() if self.s is not None:
with suppress(Exception):
self.s.close()
def start(self): def start(self):
""" """
@@ -343,10 +333,7 @@ class ElasticServer:
if not all(k in content for k in required_keys): if not all(k in content for k in required_keys):
return False return False
port = content["port"] port = content["port"]
if not (isinstance(port, int) or return isinstance(port, int) or (isinstance(port, str) and port.isdigit())
(isinstance(port, str) and port.isdigit())):
return False
return True
comm_name = None comm_name = None
if is_valid_data(data): if is_valid_data(data):
@@ -355,36 +342,31 @@ class ElasticServer:
tp = int(data["content"]["tp"]) tp = int(data["content"]["tp"])
pp = int(data["content"]["pp"]) pp = int(data["content"]["pp"])
if int(self.device_id if (
) == device_id and self.model_path == model_path and int( int(self.device_id) == device_id
self.tp) == tp and int(self.pp) == pp: and self.model_path == model_path
and int(self.tp) == tp
and int(self.pp) == pp
):
comm_name = str(addr[0]) + ":" + str(addr[1]) comm_name = str(addr[0]) + ":" + str(addr[1])
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}} ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
else: else:
logger.warning( server_desc = (int(self.device_id), self.model_path, int(self.tp), int(self.pp))
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}" client_desc = (device_id, model_path, tp, pp)
) msg = f"Received data {client_desc} does not consist with this server {server_desc}"
logger.warning(msg)
ack = { ack = {
"label": "label": "JOIN_NACK",
"JOIN_NACK", "content": msg,
"content":
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
} }
else: else:
logger.warning( logger.warning(f"Received data does not contain required fields: {data}")
f"Received data does not contain required fields: {data}") ack = {"label": "JOIN_NACK", "content": f"Received data does not contain required fields: {data}"}
ack = {
"label":
"JOIN_NACK",
"content":
f"Received data does not contain required fields: {data}"
}
try: try:
ack_str = json.dumps(ack).encode("utf-8") ack_str = json.dumps(ack).encode("utf-8")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to convert {ack} to JSON format, details: {e}")
f"Failed to convert {ack} to JSON format, details: {e}")
conn.close() conn.close()
return return
@@ -395,14 +377,10 @@ class ElasticServer:
conn.close() conn.close()
return return
if ack["content"] and isinstance(ack["content"], if ack["content"] and isinstance(ack["content"], dict) and "name" in ack["content"]:
dict) and 'name' in ack["content"]:
try: try:
p2psend = P2PSend(self.addr, data["content"]["port"], p2psend = P2PSend(self.addr, data["content"]["port"], ack["content"]["name"])
ack["content"]["name"])
p2psend.send(self.model, self.original_int8) p2psend.send(self.model, self.original_int8)
except Exception as e: except Exception as e:
logger.error( logger.error(f"P2PSend Failed to send model to {self.addr}, details: {e}")
f"P2PSend Failed to send model to {self.addr}, details: {e}"
)
conn.close() conn.close()
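The two classes above agree on a small JSON handshake: the client sends a JOIN message carrying device_id, model_path, tp, pp and a free port, and the server replies with JOIN_ACK (containing a comm name) or JOIN_NACK. The sketch below covers only the message validation, mirroring the fields visible in the diff; the sample values and address are invented.

```python
import json


def is_valid_join(data: dict) -> bool:
    # Mirrors the server-side check: a JOIN label, a content dict with the
    # expected keys, and a port that is an int or a digit string.
    if data.get("label") != "JOIN" or not isinstance(data.get("content"), dict):
        return False
    content = data["content"]
    required = ("device_id", "model_path", "tp", "pp", "port")
    if not all(k in content for k in required):
        return False
    port = content["port"]
    return isinstance(port, int) or (isinstance(port, str) and port.isdigit())


# Hypothetical request/response pair, values invented for the example.
join = {
    "label": "JOIN",
    "content": {"device_id": 0, "model_path": "/models/demo", "tp": 1, "pp": 1, "port": 45001},
}
assert is_valid_join(json.loads(json.dumps(join)))

ack = {"label": "JOIN_ACK", "content": {"name": "10.0.0.2:45001"}}
print(ack["content"]["name"])
```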

View File

@@ -48,36 +48,27 @@ def elastic_load(
# Filter sources for the current device # Filter sources for the current device
sources_this_device = [] sources_this_device = []
for s in sources: for s in sources:
if isinstance( if isinstance(s, dict) and "device_id" in s and s["device_id"] == device_id and isinstance(s["sources"], list):
s, dict
) and "device_id" in s and s["device_id"] == device_id and isinstance(
s["sources"], list):
sources_this_device += s["sources"] sources_this_device += s["sources"]
if len(sources_this_device) == 0: if len(sources_this_device) == 0:
return None return None
try: try:
# Initialize the interaction layer with the ElasticClient # Initialize the interaction layer with the ElasticClient
with ElasticClient(sources_this_device, device_id, model_path, tp, with ElasticClient(sources_this_device, device_id, model_path, tp, pp) as client_interaction_layer:
pp) as client_interaction_layer:
if client_interaction_layer.s is None or client_interaction_layer.server_addr is None: if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
raise RuntimeError( raise RuntimeError("Failed to initialize ElasticClient: socket or server_addr is None")
"Failed to initialize ElasticClient: socket or server_addr is None"
)
ack = client_interaction_layer.ack ack = client_interaction_layer.ack
if ack is None: if ack is None:
raise RuntimeError("ElasticClient.register did not return ack") raise RuntimeError("ElasticClient.register did not return ack")
t0 = time.perf_counter() t0 = time.perf_counter()
elastic_loader = P2PLoad(ack[0], elastic_loader = P2PLoad(ack[0], client_interaction_layer.server_addr, ack[1])
client_interaction_layer.server_addr,
ack[1])
model_loaded = elastic_loader.load(model=model) model_loaded = elastic_loader.load(model=model)
if model_loaded is None: if model_loaded is None:
logger.error("Failed to load model") logger.error("Failed to load model")
return None return None
logger.info("Finish elastic load (duration: {}s)".format( logger.info("Finish elastic load (duration: {}s)".format(time.perf_counter() - t0))
time.perf_counter() - t0))
return model_loaded return model_loaded
except Exception as e: except Exception as e:
logger.info(f"elastic_load error: {e}") logger.info(f"elastic_load error: {e}")
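The filter at the top of this hunk implies the shape of the SOURCE config: a list of dicts, each with a device_id and the list of "ip:port" endpoints to try for that device. A tiny illustrative version of that filter; the addresses are placeholders.

```python
# Hypothetical SOURCE config, matching the shape elastic_load expects above.
sources = [
    {"device_id": 0, "sources": ["10.0.0.1:7000", "10.0.0.2:7000"]},
    {"device_id": 1, "sources": ["10.0.0.3:7000"]},
]


def sources_for(device_id: int) -> list[str]:
    out: list[str] = []
    for s in sources:
        if isinstance(s, dict) and s.get("device_id") == device_id and isinstance(s.get("sources"), list):
            out += s["sources"]
    return out


print(sources_for(0))  # ['10.0.0.1:7000', '10.0.0.2:7000']
```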

View File

@@ -18,7 +18,6 @@ import gc
import json import json
import time import time
from copy import deepcopy from copy import deepcopy
from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -27,8 +26,7 @@ from vllm.logger import logger
from vllm.model_executor.model_loader import register_model_loader from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import initialize_model, process_weights_after_loading
initialize_model, process_weights_after_loading)
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
from .interaction.elastic import ElasticServer from .interaction.elastic import ElasticServer
@@ -41,12 +39,13 @@ class ModelNetLoaderElastic(BaseModelLoader):
""" """
A model loader that uses elastic loading for loading weights. A model loader that uses elastic loading for loading weights.
""" """
source: Optional[List[dict]]
model_path: Optional[str] source: list[dict] | None
listen_port: Optional[int] model_path: str | None
listen_port: int | None
int8_cache: str int8_cache: str
int8_cache_name: Optional[List[str]] int8_cache_name: list[str] | None
output_prefix: Optional[str] output_prefix: str | None
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
""" """
@@ -63,18 +62,15 @@ class ModelNetLoaderElastic(BaseModelLoader):
extra = load_config.model_loader_extra_config extra = load_config.model_loader_extra_config
if extra and "CONFIG_FILE" in extra: if extra and "CONFIG_FILE" in extra:
try: try:
logger.info( logger.info(f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ...")
f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..." with open(extra["CONFIG_FILE"]) as f:
)
with open(extra["CONFIG_FILE"], 'r') as f:
config = json.load(f) config = json.load(f)
except FileNotFoundError: except FileNotFoundError:
logger.error("CONFIG_FILE not found") logger.error("CONFIG_FILE not found")
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error("CONFIG_FILE is not a valid JSON file") logger.error("CONFIG_FILE is not a valid JSON file")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Unexpected error while reading CONFIG_FILE: {e}")
f"Unexpected error while reading CONFIG_FILE: {e}")
if config is None and extra: if config is None and extra:
logger.info("Reading configs in model_loader_extra_config ...") logger.info("Reading configs in model_loader_extra_config ...")
@@ -82,19 +78,30 @@ class ModelNetLoaderElastic(BaseModelLoader):
config = config or {} config = config or {}
for key, attr, checker, caster, default in [ for key, attr, checker, caster, default in [
("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, ("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, None),
None), ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, None),
("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, (
None), "LISTEN_PORT",
("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or "listen_port",
(isinstance(v, str) and v.isdigit()), lambda v: int(v), None), lambda v: isinstance(v, int) or (isinstance(v, str) and v.isdigit()),
("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v. lambda v: int(v),
lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'), None,
("INT8_CACHE_NAME", "int8_cache_name", ),
lambda v: isinstance(v, list), lambda v: v, None), (
("OUTPUT_PREFIX", "output_prefix", "INT8_CACHE",
lambda v: isinstance(v, str) and is_valid_path_prefix(v), "int8_cache",
lambda v: v, None), lambda v: isinstance(v, str) and v.lower() in ["hbm", "dram", "no"],
lambda v: v.lower(),
"no",
),
("INT8_CACHE_NAME", "int8_cache_name", lambda v: isinstance(v, list), lambda v: v, None),
(
"OUTPUT_PREFIX",
"output_prefix",
lambda v: isinstance(v, str) and is_valid_path_prefix(v),
lambda v: v,
None,
),
]: ]:
v = config.get(key, default) v = config.get(key, default)
if not checker(v): if not checker(v):
@@ -116,8 +123,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
self.output_prefix, self.output_prefix,
) )
def load_model(self, vllm_config: VllmConfig, def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module:
model_config: ModelConfig) -> nn.Module:
""" """
Loads the model using the specified configuration. Loads the model using the specified configuration.
@@ -140,15 +146,18 @@ class ModelNetLoaderElastic(BaseModelLoader):
device_id = torch.distributed.get_rank() device_id = torch.distributed.get_rank()
if (self.source is None or not isinstance(self.source, list) if (
or device_id not in [ self.source is None
one_device["device_id"] for one_device in self.source if or not isinstance(self.source, list)
isinstance(one_device, dict) and "device_id" in one_device or device_id
]): not in [
logger.warning( one_device["device_id"]
"Did not get valid source info, use DefaultModelLoader") for one_device in self.source
model, need_process_weights_after_loading = self.revert_to_default( if isinstance(one_device, dict) and "device_id" in one_device
model_config, vllm_config, device_config) ]
):
logger.warning("Did not get valid source info, use DefaultModelLoader")
model, need_process_weights_after_loading = self.revert_to_default(model_config, vllm_config, device_config)
else: else:
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
@@ -158,8 +167,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = initialize_model(vllm_config=vllm_config, model = initialize_model(vllm_config=vllm_config, model_config=model_config)
model_config=model_config)
start_elastic_load = time.perf_counter() start_elastic_load = time.perf_counter()
model = elastic_load( model = elastic_load(
@@ -171,43 +179,39 @@ class ModelNetLoaderElastic(BaseModelLoader):
pp=parallel_config.pipeline_parallel_size, pp=parallel_config.pipeline_parallel_size,
) )
end_elastic_load = time.perf_counter() end_elastic_load = time.perf_counter()
logger.info( logger.info(f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}")
f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}"
)
need_process_weights_after_loading = True need_process_weights_after_loading = True
if model is None: if model is None:
logger.warning( logger.warning("Netloader elastic loading fails, use load format DefaultModelLoader")
"Netloader elastic loading fails, use load format DefaultModelLoader"
)
vllm_config = vllm_config_backup vllm_config = vllm_config_backup
model_config = model_config_backup model_config = model_config_backup
del model del model
gc.collect() gc.collect()
if device_config.device_type == 'npu': if device_config.device_type == "npu":
logger.info("Empty NPU cache") logger.info("Empty NPU cache")
torch.npu.empty_cache() torch.npu.empty_cache()
elif device_config.device_type == 'cuda': elif device_config.device_type == "cuda":
logger.info("Empty CUDA cache") logger.info("Empty CUDA cache")
torch.cuda.empty_cache() torch.cuda.empty_cache()
model, need_process_weights_after_loading = self.revert_to_default( model, need_process_weights_after_loading = self.revert_to_default(
model_config, vllm_config, device_config) model_config, vllm_config, device_config
)
start_elastic_server = time.perf_counter() start_elastic_server = time.perf_counter()
# start elastic server # start elastic server
if model is not None and ( if model is not None and (
(self.listen_port and self.listen_port in range(1024, 65535)) or (self.listen_port and self.listen_port in range(1024, 65535)) or (self.listen_port is None)
(self.listen_port is None)): ):
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
driver_ip = get_ip() driver_ip = get_ip()
if driver_ip == '0.0.0.0': if driver_ip == "0.0.0.0":
logger.error( logger.error("Driver IP is not set, skip to start Netloader server")
"Driver IP is not set, skip to start Netloader server")
else: else:
if self.listen_port is None: if self.listen_port is None:
self.listen_port = find_free_port() self.listen_port = find_free_port()
@@ -220,21 +224,14 @@ class ModelNetLoaderElastic(BaseModelLoader):
if self.output_prefix is not None: if self.output_prefix is not None:
try: try:
with open(self.output_prefix + str(device_id) + '.txt', with open(self.output_prefix + str(device_id) + ".txt", "w") as file:
'w') as file:
file.write(f"{driver_ip}:{self.listen_port}") file.write(f"{driver_ip}:{self.listen_port}")
logger.info( logger.info(f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}")
f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}"
)
except FileNotFoundError: except FileNotFoundError:
logger.error( logger.error(f"File path {self.output_prefix + str(device_id)} does not exist.")
f"File path {self.output_prefix + str(device_id)} does not exist."
)
except PermissionError: except PermissionError:
logger.error( logger.error(f"No permission to write to file {self.output_prefix + str(device_id)}.")
f"No permission to write to file {self.output_prefix + str(device_id)}." except OSError as e:
)
except IOError as e:
logger.error( logger.error(
f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}" f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}"
) )
@@ -242,31 +239,30 @@ class ModelNetLoaderElastic(BaseModelLoader):
logger.error(f"Unknown error: {e}") logger.error(f"Unknown error: {e}")
try: try:
assert isinstance( assert isinstance(self.listen_port, int), f"listen port should be int but get {self.listen_port}"
self.listen_port, int
), f"listen port should be int but get {self.listen_port}"
elastic_server = ElasticServer( elastic_server = ElasticServer(
driver_ip, self.listen_port, model, device_id, driver_ip,
self.model_path, parallel_config.tensor_parallel_size, self.listen_port,
model,
device_id,
self.model_path,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size,
self.int8_cache, self.int8_cache_name) self.int8_cache,
self.int8_cache_name,
)
elastic_server.start() elastic_server.start()
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to start Netloader server for rank: {device_id}, details: {e}")
f"Failed to start Netloader server for rank: {device_id}, details: {e}"
)
else: else:
logger.info("Skip to start Netloader server") logger.info("Skip to start Netloader server")
end_elastic_server = time.perf_counter() end_elastic_server = time.perf_counter()
logger.info( logger.info(f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}")
f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}"
)
if need_process_weights_after_loading: if need_process_weights_after_loading:
process_weights_after_loading(model, model_config, process_weights_after_loading(model, model_config, torch.device(device_config.device))
torch.device(device_config.device))
if model is None: if model is None:
logger.error("NetLoader elastic loads model fails") logger.error("NetLoader elastic loads model fails")
@@ -274,8 +270,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
return model.eval() return model.eval()
def revert_to_default(self, model_config, vllm_config, def revert_to_default(self, model_config, vllm_config, device_config) -> tuple[nn.Module, bool]:
device_config) -> Tuple[nn.Module, bool]:
""" """
Reverts to the default model loading logic when elastic loading fails or is not applicable. Reverts to the default model loading logic when elastic loading fails or is not applicable.
@@ -300,19 +295,15 @@ class ModelNetLoaderElastic(BaseModelLoader):
default_model_loader = DefaultModelLoader(self.load_config) default_model_loader = DefaultModelLoader(self.load_config)
if model_config.quantization is None: if model_config.quantization is None:
model = default_model_loader.load_model(vllm_config=vllm_config, model = default_model_loader.load_model(vllm_config=vllm_config, model_config=model_config)
model_config=model_config)
need_process_weights_after_loading = False need_process_weights_after_loading = False
else: else:
logger.warning( logger.warning("Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading ")
"Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading "
)
need_process_weights_after_loading = True need_process_weights_after_loading = True
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = initialize_model(vllm_config=vllm_config, model = initialize_model(vllm_config=vllm_config, model_config=model_config)
model_config=model_config)
default_model_loader.load_weights(model, model_config) default_model_loader.load_weights(model, model_config)
model = model.eval() model = model.eval()
@@ -321,6 +312,5 @@ class ModelNetLoaderElastic(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
pass pass
def load_weights(self, model: nn.Module, def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
model_config: ModelConfig) -> None:
pass pass
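A pattern worth noting in ModelNetLoaderElastic.__init__ above: every config key is described by a (key, attribute, checker, caster, default) tuple and handled in one loop. Below is a stripped-down sketch of that table-driven validation; falling back to the default on a failed check is a simplification for the demo, not a claim about the original loader's behaviour.

```python
class DemoConfig:
    def __init__(self, raw: dict):
        # (key, attr, checker, caster, default) -- same shape as the tuples above.
        spec = [
            ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, None),
            ("LISTEN_PORT", "listen_port",
             lambda v: isinstance(v, int) or (isinstance(v, str) and v.isdigit()),
             lambda v: int(v), None),
            ("INT8_CACHE", "int8_cache",
             lambda v: isinstance(v, str) and v.lower() in ("hbm", "dram", "no"),
             lambda v: v.lower(), "no"),
        ]
        for key, attr, checker, caster, default in spec:
            value = raw.get(key, default)
            if value is not None and not checker(value):
                value = default  # demo-only fallback instead of raising
            setattr(self, attr, caster(value) if value is not None else None)


cfg = DemoConfig({"MODEL": "/models/demo", "LISTEN_PORT": "8123"})
print(cfg.model_path, cfg.listen_port, cfg.int8_cache)  # /models/demo 8123 no
```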

View File

@@ -29,7 +29,7 @@ def find_free_port():
- A free port number. - A free port number.
""" """
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0)) s.bind(("", 0))
return s.getsockname()[1] return s.getsockname()[1]
@@ -47,20 +47,17 @@ def is_valid_path_prefix(path_prefix):
return False return False
if re.search(r'[<>:"|?*]', path_prefix): if re.search(r'[<>:"|?*]', path_prefix):
logger.warning( logger.warning(f"The path prefix {path_prefix} contains illegal characters.")
f'The path prefix {path_prefix} contains illegal characters.')
return False return False
if path_prefix.startswith('/') or path_prefix.startswith('\\'): if path_prefix.startswith("/") or path_prefix.startswith("\\"):
if not os.path.exists(os.path.dirname(path_prefix)): if not os.path.exists(os.path.dirname(path_prefix)):
logger.warning( logger.warning(f"The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.")
f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.'
)
return False return False
else: else:
if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))): if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
logger.warning( logger.warning(
f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.' f"The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist."
) )
return False return False
return True return True
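Usage note for the helpers above: find_free_port binds to port 0 so the OS hands back an ephemeral port, and is_valid_path_prefix rejects prefixes containing illegal characters or whose parent directory does not exist. A minimal runnable check of the port helper:

```python
import socket


def find_free_port() -> int:
    # Same idea as the helper above: bind to port 0 and let the OS choose.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


port = find_free_port()
assert 0 < port < 65536
print(f"picked ephemeral port {port}")
```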

View File

@@ -23,9 +23,8 @@ import vllm_ascend.patch.platform.patch_sched_yield # noqa
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv( if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
"EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
if envs.VLLM_ASCEND_BALANCE_SCHEDULING and vllm_version_is('0.14.0'): if envs.VLLM_ASCEND_BALANCE_SCHEDULING and vllm_version_is("0.14.0"):
import vllm_ascend.patch.platform.patch_balance_schedule # noqa import vllm_ascend.patch.platform.patch_balance_schedule # noqa

View File

@@ -7,17 +7,14 @@ import torch.distributed as dist
import vllm import vllm
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import \ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
KVConnectorMetadata
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.transformers_utils.config import \ from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
maybe_register_config_serialize_by_value
from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import (SchedulingPolicy, from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
create_request_queue)
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
@@ -30,7 +27,6 @@ logger = init_logger(__name__)
class BalanceScheduler(Scheduler): class BalanceScheduler(Scheduler):
def __init__( def __init__(
self, self,
vllm_config, vllm_config,
@@ -41,9 +37,15 @@ class BalanceScheduler(Scheduler):
include_finished_set: bool = False, include_finished_set: bool = False,
log_stats: bool = False, log_stats: bool = False,
) -> None: ) -> None:
super().__init__(vllm_config, kv_cache_config, super().__init__(
structured_output_manager, block_size, mm_registry, vllm_config,
include_finished_set, log_stats) kv_cache_config,
structured_output_manager,
block_size,
mm_registry,
include_finished_set,
log_stats,
)
# Balance scheduling. # Balance scheduling.
self.balance_queue = [ self.balance_queue = [
torch.tensor([0], dtype=torch.int, device="cpu") torch.tensor([0], dtype=torch.int, device="cpu")
@@ -51,9 +53,7 @@ class BalanceScheduler(Scheduler):
] ]
def balance_gather(self, dp_group): def balance_gather(self, dp_group):
running_tensor = torch.tensor([len(self.running)], running_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")
dtype=torch.int,
device="cpu")
dist.all_gather(self.balance_queue, running_tensor, group=dp_group) dist.all_gather(self.balance_queue, running_tensor, group=dp_group)
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
@@ -89,33 +89,32 @@ class BalanceScheduler(Scheduler):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
if (request.num_output_placeholders > 0 if (
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1). request.num_output_placeholders > 0
# Since output placeholders are also included in the computed tokens # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
# count, we subtract (num_output_placeholders - 1) to remove any draft # Since output placeholders are also included in the computed tokens
# tokens, so that we can be sure no further steps are needed even if # count, we subtract (num_output_placeholders - 1) to remove any draft
# they are all rejected. # tokens, so that we can be sure no further steps are needed even if
and request.num_computed_tokens + 2 - # they are all rejected.
request.num_output_placeholders and request.num_computed_tokens + 2 - request.num_output_placeholders
>= request.num_prompt_tokens + request.max_tokens): >= request.num_prompt_tokens + request.max_tokens
):
# Async scheduling: Avoid scheduling an extra step when we are sure that # Async scheduling: Avoid scheduling an extra step when we are sure that
# the previous step has reached request.max_tokens. We don't schedule # the previous step has reached request.max_tokens. We don't schedule
# partial draft tokens since this prevents uniform decode optimizations. # partial draft tokens since this prevents uniform decode optimizations.
req_index += 1 req_index += 1
continue continue
num_new_tokens = (request.num_tokens_with_spec + num_new_tokens = (
request.num_output_placeholders - request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
request.num_computed_tokens) )
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len. # Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding. # This is necessary when using spec decoding.
num_new_tokens = min( num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
@@ -174,20 +173,17 @@ class BalanceScheduler(Scheduler):
self.running.remove(preempted_req) self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs: if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req) scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[ token_budget += num_scheduled_tokens[preempted_req.request_id]
preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id) req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id) num_scheduled_tokens.pop(preempted_req.request_id)
scheduled_spec_decode_tokens.pop( scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
preempted_req.request_id, None) preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
preempted_req.request_id, None)
if preempted_encoder_inputs: if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted # Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step. # request had encoder inputs scheduled in this step.
num_embeds_to_restore = sum( num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i) preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
for i in preempted_encoder_inputs) )
encoder_compute_budget += num_embeds_to_restore encoder_compute_budget += num_embeds_to_restore
req_index -= 1 req_index -= 1
else: else:
@@ -212,23 +208,20 @@ class BalanceScheduler(Scheduler):
# Speculative decode related. # Speculative decode related.
if request.spec_token_ids: if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens + num_scheduled_spec_tokens = (
request.num_computed_tokens - num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
request.num_tokens - )
request.num_output_placeholders)
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. # Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:] del request.spec_token_ids[num_scheduled_spec_tokens:]
- scheduled_spec_decode_tokens[request.request_id] = (
- request.spec_token_ids)
+ scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
# New spec tokens will be set in `update_draft_token_ids` before the # New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable. # next step when applicable.
request.spec_token_ids = [] request.spec_token_ids = []
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
- scheduled_encoder_inputs[request.request_id] = (
- encoder_inputs_to_schedule)
+ scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -243,8 +236,10 @@ class BalanceScheduler(Scheduler):
scheduled_loras: set[int] = set() scheduled_loras: set[int] = set()
if self.lora_config: if self.lora_config:
- scheduled_loras = set(
- req.lora_request.lora_int_id for req in scheduled_running_reqs
- if req.lora_request and req.lora_request.lora_int_id > 0)
+ scheduled_loras = set(
+ req.lora_request.lora_int_id
+ for req in scheduled_running_reqs
+ if req.lora_request and req.lora_request.lora_int_id > 0
+ )
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be # Use a temporary RequestQueue to collect requests that need to be
@@ -257,9 +252,7 @@ class BalanceScheduler(Scheduler):
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
- balance_flag = (max(
- t.item()
- for t in self.balance_queue) == self.max_num_running_reqs)
+ balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
if balance_flag: if balance_flag:
break break
@@ -292,9 +285,14 @@ class BalanceScheduler(Scheduler):
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
- if (self.lora_config and request.lora_request and
- (len(scheduled_loras) == self.lora_config.max_loras and
- request.lora_request.lora_int_id not in scheduled_loras)):
+ if (
+ self.lora_config
+ and request.lora_request
+ and (
+ len(scheduled_loras) == self.lora_config.max_loras
+ and request.lora_request.lora_int_id not in scheduled_loras
+ )
+ ):
# Scheduling would exceed max_loras, skip. # Scheduling would exceed max_loras, skip.
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
@@ -306,14 +304,15 @@ class BalanceScheduler(Scheduler):
# Get already-cached tokens. # Get already-cached tokens.
if request.num_computed_tokens == 0: if request.num_computed_tokens == 0:
# Get locally-cached tokens. # Get locally-cached tokens.
- new_computed_blocks, num_new_local_computed_tokens = (
- self.kv_cache_manager.get_computed_blocks(request))
+ new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
+ request
+ )
# Get externally-cached tokens if using a KVConnector. # Get externally-cached tokens if using a KVConnector.
if self.connector is not None: if self.connector is not None:
- ext_tokens, load_kv_async = (
- self.connector.get_num_new_matched_tokens(
- request, num_new_local_computed_tokens))
+ ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
+ request, num_new_local_computed_tokens
+ )
if ext_tokens is None: if ext_tokens is None:
# The request cannot be scheduled because # The request cannot be scheduled because
@@ -327,8 +326,7 @@ class BalanceScheduler(Scheduler):
num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external). # Total computed tokens (local + external).
- num_computed_tokens = (num_new_local_computed_tokens +
- num_external_computed_tokens)
+ num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
else: else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
@@ -356,8 +354,7 @@ class BalanceScheduler(Scheduler):
# chunked prefill has to be enabled explicitly to allow # chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked # pooling requests to be chunked
- if (not self.scheduler_config.enable_chunked_prefill
- and num_new_tokens > token_budget):
+ if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
# If chunked_prefill is disabled, # If chunked_prefill is disabled,
# we can stop the scheduling here. # we can stop the scheduling here.
break break
@@ -388,9 +385,7 @@ class BalanceScheduler(Scheduler):
# extra block gets allocated which # extra block gets allocated which
# creates a mismatch between the number # creates a mismatch between the number
# of local and remote blocks. # of local and remote blocks.
- effective_lookahead_tokens = (0 if request.num_computed_tokens
- == 0 else
- self.num_lookahead_tokens)
+ effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
# Determine if we need to allocate cross-attention blocks. # Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs: if self.is_encoder_decoder and request.has_encoder_inputs:
@@ -398,8 +393,7 @@ class BalanceScheduler(Scheduler):
# always padded to the maximum length. If we support other # always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we # encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed. # want to only allocate what is needed.
- num_encoder_tokens = (
- self.scheduler_config.max_num_encoder_input_tokens)
+ num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
else: else:
num_encoder_tokens = 0 num_encoder_tokens = 0
@@ -442,20 +436,17 @@ class BalanceScheduler(Scheduler):
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
- request.record_event(EngineCoreEventType.SCHEDULED,
- scheduled_timestamp)
+ request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
if request.status == RequestStatus.WAITING: if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request) scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED: elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request) scheduled_resumed_reqs.append(request)
else: else:
- raise RuntimeError(
- f"Invalid request status: {request.status}")
+ raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
- req_to_new_blocks[request.request_id] = (
- self.kv_cache_manager.get_blocks(request.request_id))
+ req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
@@ -465,8 +456,7 @@ class BalanceScheduler(Scheduler):
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
- scheduled_encoder_inputs[request.request_id] = (
- encoder_inputs_to_schedule)
+ scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
@@ -476,8 +466,7 @@ class BalanceScheduler(Scheduler):
for i in external_load_encoder_input: for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None: if self.ec_connector is not None:
- self.ec_connector.update_state_after_alloc(
- request, i)
+ self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests) self.waiting.prepend_requests(skipped_waiting_requests)
@@ -491,20 +480,15 @@ class BalanceScheduler(Scheduler):
# Since some requests in the RUNNING queue may not be scheduled in # Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than # this step, the total number of scheduled requests can be smaller than
# len(self.running). # len(self.running).
- assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
- scheduled_running_reqs) <= len(self.running)
+ assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
# Get the longest common prefix among all requests in the running queue. # Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention. # This can be potentially used for cascade attention.
- num_common_prefix_blocks = [0] * len(
- self.kv_cache_config.kv_cache_groups)
- with record_function_or_nullcontext(
- "schedule: get_num_common_prefix_blocks"):
+ num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
+ with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running: if self.running:
any_request = self.running[0] any_request = self.running[0]
- num_common_prefix_blocks = (
- self.kv_cache_manager.get_num_common_prefix_blocks(
- any_request.request_id))
+ num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
# Construct the scheduler output. # Construct the scheduler output.
if self.use_v2_model_runner: if self.use_v2_model_runner:
@@ -515,17 +499,16 @@ class BalanceScheduler(Scheduler):
req, req,
req_to_new_blocks[req.request_id].get_block_ids(), req_to_new_blocks[req.request_id].get_block_ids(),
req._all_token_ids, req._all_token_ids,
- ) for req in scheduled_new_reqs
+ )
+ for req in scheduled_new_reqs
] ]
else: else:
new_reqs_data = [ new_reqs_data = [
- NewRequestData.from_request(
- req, req_to_new_blocks[req.request_id].get_block_ids())
+ NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs for req in scheduled_new_reqs
] ]
- with record_function_or_nullcontext(
- "schedule: make_cached_request_data"):
+ with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data( cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs, scheduled_running_reqs,
scheduled_resumed_reqs, scheduled_resumed_reqs,
@@ -546,15 +529,13 @@ class BalanceScheduler(Scheduler):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs, scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks, num_common_prefix_blocks=num_common_prefix_blocks,
- preempted_req_ids={req.request_id
- for req in preempted_reqs},
+ preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler, # finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step. # instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between # It contains the request IDs that are finished in between
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
- free_encoder_mm_hashes=self.encoder_cache_manager.
- get_freed_mm_hashes(),
+ free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
) )
# NOTE(Kuntai): this function is designed for multiple purposes: # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -562,14 +543,12 @@ class BalanceScheduler(Scheduler):
# 2. Wrap up all the KV cache load / save ops into an opaque object # 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector # 3. Clear the internal states of the connector
if self.connector is not None: if self.connector is not None:
- meta: KVConnectorMetadata = self.connector.build_connector_meta(
- scheduler_output)
+ meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector # Build the connector meta for ECConnector
if self.ec_connector is not None: if self.ec_connector is not None:
- ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
- scheduler_output)
+ ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
scheduler_output.ec_connector_metadata = ec_meta scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"): with record_function_or_nullcontext("schedule: update_after_schedule"):
@@ -578,7 +557,6 @@ class BalanceScheduler(Scheduler):
class BalanceDPEngineCoreProc(DPEngineCoreProc): class BalanceDPEngineCoreProc(DPEngineCoreProc):
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case.""" """Core busy loop of the EngineCore for data parallel case."""
@@ -602,23 +580,23 @@ class BalanceDPEngineCoreProc(DPEngineCoreProc):
self.execute_dummy_batch() self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs. # 3) All-reduce operation to determine global unfinished reqs.
- self.engines_running = self._has_global_unfinished_reqs(
- local_unfinished_reqs)
+ self.engines_running = self._has_global_unfinished_reqs(local_unfinished_reqs)
self.scheduler.balance_gather(self.dp_group) self.scheduler.balance_gather(self.dp_group)
if not self.engines_running: if not self.engines_running:
if self.dp_rank == 0 or not self.has_coordinator: if self.dp_rank == 0 or not self.has_coordinator:
# Notify client that we are pausing the loop. # Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.", logger.debug("Wave %d finished, pausing engine loop.", self.current_wave)
self.current_wave)
# In the coordinator case, dp rank 0 sends updates to the # In the coordinator case, dp rank 0 sends updates to the
# coordinator. Otherwise (offline spmd case), each rank # coordinator. Otherwise (offline spmd case), each rank
# sends the update to its colocated front-end process. # sends the update to its colocated front-end process.
client_index = -1 if self.has_coordinator else 0 client_index = -1 if self.has_coordinator else 0
- self.output_queue.put_nowait((
- client_index,
- EngineCoreOutputs(wave_complete=self.current_wave),
- ))
+ self.output_queue.put_nowait(
+ (
+ client_index,
+ EngineCoreOutputs(wave_complete=self.current_wave),
+ )
+ )
# Increment wave count and reset step counter. # Increment wave count and reset step counter.
self.current_wave += 1 self.current_wave += 1
self.step_counter = 0 self.step_counter = 0
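
> Note: the reformatted busy-loop tail above is easier to follow outside the diff. The snippet below is a minimal, self-contained sketch of the wave-completion path under stated assumptions: `EngineCoreOutputs`, `notify_wave_complete`, and the bare `queue.Queue` are illustrative stand-ins, not the actual vLLM classes or signatures.

```python
import queue
from dataclasses import dataclass


@dataclass
class EngineCoreOutputs:  # simplified stand-in, not the real vLLM class
    wave_complete: int | None = None


output_queue: "queue.Queue[tuple[int, EngineCoreOutputs]]" = queue.Queue()


def notify_wave_complete(engines_running: bool, dp_rank: int,
                         has_coordinator: bool, current_wave: int) -> int:
    """Sketch of the pause path at the end of the data-parallel busy loop."""
    if engines_running:
        return current_wave  # still work to do, keep the wave open
    if dp_rank == 0 or not has_coordinator:
        # dp rank 0 reports to the coordinator (-1); otherwise each rank
        # notifies its colocated front-end process (client index 0).
        client_index = -1 if has_coordinator else 0
        output_queue.put_nowait((client_index, EngineCoreOutputs(wave_complete=current_wave)))
    return current_wave + 1  # increment the wave count once the loop pauses


assert notify_wave_complete(False, dp_rank=0, has_coordinator=True, current_wave=3) == 4
```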

View File

@@ -1,21 +1,21 @@
import vllm.distributed.ec_transfer.ec_connector.example_connector import vllm.distributed.ec_transfer.ec_connector.example_connector
from safetensors.torch import load_file from safetensors.torch import load_file
- from vllm.distributed.ec_transfer.ec_connector.example_connector import (
- ECConnectorMetadata, ECExampleConnector)
+ from vllm.distributed.ec_transfer.ec_connector.example_connector import ECConnectorMetadata, ECExampleConnector
from vllm.logger import logger from vllm.logger import logger
class AscendECExampleConnector(ECExampleConnector): class AscendECExampleConnector(ECExampleConnector):
def start_load_caches(self, encoder_cache, **kwargs) -> None: def start_load_caches(self, encoder_cache, **kwargs) -> None:
metadata: ECConnectorMetadata = self._get_connector_metadata() metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECConnectorMetadata) assert isinstance(metadata, ECConnectorMetadata)
assert encoder_cache is not None assert encoder_cache is not None
if metadata is None: if metadata is None:
- logger.warning((
- "In connector.start_load_caches, ",
- "but the connector metadata is None",
- ))
+ logger.warning(
+ (
+ "In connector.start_load_caches, ",
+ "but the connector metadata is None",
+ )
+ )
return return
# Load the EC for each mm data # Load the EC for each mm data
for mm_data in metadata.mm_datas: for mm_data in metadata.mm_datas:
@@ -24,8 +24,7 @@ class AscendECExampleConnector(ECExampleConnector):
filename = self._generate_filename_debug(mm_data.mm_hash) filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = load_file(filename)["ec_cache"].npu() ec_cache = load_file(filename)["ec_cache"].npu()
encoder_cache[mm_data.mm_hash] = ec_cache encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
mm_data.mm_hash)
vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector
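
> Note: the last line above is the same monkey-patch idiom used throughout `vllm_ascend/patch/`: subclass the upstream class, override what needs Ascend-specific behaviour, then rebind the upstream module attribute. A minimal self-contained sketch of that idiom follows; every name in it is illustrative, not part of vLLM's API.

```python
import types

# Stand-in for an upstream vLLM module and class (illustrative only).
upstream = types.ModuleType("upstream.example_connector")


class ExampleConnector:
    def start_load_caches(self, cache: dict) -> None:
        cache["device"] = "cpu"


upstream.ExampleConnector = ExampleConnector


# The patch module subclasses and overrides, then rebinds the module attribute
# so later lookups through `upstream` resolve to the Ascend variant.
class AscendExampleConnector(ExampleConnector):
    def start_load_caches(self, cache: dict) -> None:
        cache["device"] = "npu"


upstream.ExampleConnector = AscendExampleConnector

cache: dict = {}
upstream.ExampleConnector().start_load_caches(cache)
assert cache["device"] == "npu"
```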

View File

@@ -38,7 +38,8 @@ def verify_and_update_config(cls, vllm_config) -> None:
block_size=1, block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config), num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(), head_size=model_config.get_head_size(),
- dtype=kv_cache_dtype).page_size_bytes
+ dtype=kv_cache_dtype,
+ ).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls( model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture, model_config.architecture,
@@ -58,23 +59,20 @@ def verify_and_update_config(cls, vllm_config) -> None:
# block size to multiple of 16, so let's suggest a value # block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible # that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead). # with mamba layers, use FlashInfer instead).
- attn_block_size = block_alignment_bytes * cdiv(
- mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
+ attn_block_size = block_alignment_bytes * cdiv(mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
# override attention block size if either (a) the # override attention block size if either (a) the
# user has not set it or (b) the user has set it # user has not set it or (b) the user has set it
# too small. # too small.
- if (cache_config.block_size is None
- or cache_config.block_size < attn_block_size):
+ if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size cache_config.block_size = attn_block_size
logger.info( logger.info(
"Setting attention block size to %d tokens " "Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.",
"to ensure that attention page size is >= mamba page size.", attn_block_size,
attn_block_size) )
# compute new attention page size # compute new attention page size
- attn_page_size = \
- cache_config.block_size * attn_page_size_1_token
+ attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size assert attn_page_size >= mamba_page_size
@@ -83,15 +81,15 @@ def verify_and_update_config(cls, vllm_config) -> None:
return return
# pad mamba page size to exactly match attention # pad mamba page size to exactly match attention
- if (cache_config.mamba_page_size_padded is None
- or cache_config.mamba_page_size_padded != attn_page_size):
- cache_config.mamba_page_size_padded = (attn_page_size)
- mamba_padding_pct = 100 * (attn_page_size -
- mamba_page_size) / mamba_page_size
+ if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
+ cache_config.mamba_page_size_padded = attn_page_size
+ mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
logger.info( logger.info(
"Padding mamba page size by %.2f%% to ensure " "Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are " "that mamba page size and attention page size are "
"exactly equal.", mamba_padding_pct) "exactly equal.",
mamba_padding_pct,
)
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
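
> Note: the alignment arithmetic in this hunk is easier to follow with concrete numbers. The sketch below mirrors the reformatted logic with hypothetical page sizes; the local `cdiv` helper and every numeric value are illustrative assumptions, not taken from the PR.

```python
from math import ceil


def cdiv(a: int, b: int) -> int:
    # Ceiling division, standing in for vLLM's cdiv utility.
    return ceil(a / b)


# Hypothetical sizes, chosen only to show the rounding behaviour.
attn_page_size_1_token = 1152   # bytes of attention KV cache per token
mamba_page_size = 262_144       # bytes required by one mamba state page
block_alignment = 16            # attention block size stays a multiple of this

# Smallest aligned attention block whose page is at least as large as the mamba page.
attn_block_size = block_alignment * cdiv(mamba_page_size, block_alignment * attn_page_size_1_token)
attn_page_size = attn_block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size

# Pad the mamba page size so both page sizes match exactly, as the patch does.
mamba_page_size_padded = attn_page_size
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
print(attn_block_size, mamba_page_size_padded, round(mamba_padding_pct, 2))  # 240 276480 5.47
```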

View File

@@ -7,19 +7,20 @@ from multiprocessing.synchronize import Lock as LockType
import vllm.v1.executor.multiproc_executor import vllm.v1.executor.multiproc_executor
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
- from vllm.distributed.device_communicators.shm_broadcast import (Handle,
- MessageQueue)
- from vllm.utils.network_utils import (get_distributed_init_method,
- get_loopback_ip, get_open_port)
+ from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
+ from vllm.utils.network_utils import get_distributed_init_method, get_loopback_ip, get_open_port
from vllm.utils.system_utils import get_mp_context from vllm.utils.system_utils import get_mp_context
from vllm.v1.executor.abstract import FailureCallback from vllm.v1.executor.abstract import FailureCallback
- from vllm.v1.executor.multiproc_executor import (
- FutureWrapper, MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc,
- set_multiprocessing_worker_envs)
+ from vllm.v1.executor.multiproc_executor import (
+ FutureWrapper,
+ MultiprocExecutor,
+ UnreadyWorkerProcHandle,
+ WorkerProc,
+ set_multiprocessing_worker_envs,
+ )
class AscendMultiprocExecutor(MultiprocExecutor): class AscendMultiprocExecutor(MultiprocExecutor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up # Call self.shutdown at exit to clean up
# and ensure workers will be terminated. # and ensure workers will be terminated.
@@ -32,7 +33,8 @@ class AscendMultiprocExecutor(MultiprocExecutor):
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, ( assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
f"global world_size ({self.parallel_config.world_size}) must be " f"global world_size ({self.parallel_config.world_size}) must be "
f"divisible by nnodes_within_dp " f"divisible by nnodes_within_dp "
f"({self.parallel_config.nnodes_within_dp}). ") f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size
@@ -41,7 +43,8 @@ class AscendMultiprocExecutor(MultiprocExecutor):
f"world_size ({self.world_size}) must be equal to the " f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}) x prefill_context" f"_parallel_size ({pp_parallel_size}) x prefill_context"
f"_parallel_size ({pcp_parallel_size}). ") f"_parallel_size ({pcp_parallel_size}). "
)
# Set multiprocessing envs # Set multiprocessing envs
set_multiprocessing_worker_envs() set_multiprocessing_worker_envs()
@@ -49,8 +52,7 @@ class AscendMultiprocExecutor(MultiprocExecutor):
# Multiprocessing-based executor does not support multi-node setting. # Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address # Since it only works for single node, we can use the loopback address
# get_loopback_ip() for communication. # get_loopback_ip() for communication.
- distributed_init_method = get_distributed_init_method(
- get_loopback_ip(), get_open_port())
+ distributed_init_method = get_distributed_init_method(get_loopback_ip(), get_open_port())
self.rpc_broadcast_mq: MessageQueue | None = None self.rpc_broadcast_mq: MessageQueue | None = None
scheduler_output_handle: Handle | None = None scheduler_output_handle: Handle | None = None
# Initialize worker and set up message queues for SchedulerOutputs # Initialize worker and set up message queues for SchedulerOutputs
@@ -72,8 +74,7 @@ class AscendMultiprocExecutor(MultiprocExecutor):
unready_workers: list[UnreadyWorkerProcHandle] = [] unready_workers: list[UnreadyWorkerProcHandle] = []
success = False success = False
try: try:
- global_start_rank = (self.local_world_size *
- self.parallel_config.node_rank_within_dp)
+ global_start_rank = self.local_world_size * self.parallel_config.node_rank_within_dp
for local_rank in range(self.local_world_size): for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank global_rank = global_start_rank + local_rank
unready_workers.append( unready_workers.append(
@@ -84,7 +85,8 @@ class AscendMultiprocExecutor(MultiprocExecutor):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle, input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock, shared_worker_lock=shared_worker_lock,
- ))
+ )
+ )
# Workers must be created before wait_for_ready to avoid # Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync. # deadlock, since worker.init_device() does a device sync.
@@ -101,13 +103,11 @@ class AscendMultiprocExecutor(MultiprocExecutor):
if self.parallel_config.node_rank_within_dp == 0: if self.parallel_config.node_rank_within_dp == 0:
for rank in range(self.world_size): for rank in range(self.world_size):
if rank < self.local_world_size: if rank < self.local_world_size:
- local_message_queue = self.workers[
- rank].worker_response_mq
+ local_message_queue = self.workers[rank].worker_response_mq
assert local_message_queue is not None assert local_message_queue is not None
self.response_mqs.append(local_message_queue) self.response_mqs.append(local_message_queue)
else: else:
- remote_message_queue = self.workers[
- 0].peer_worker_response_mqs[rank]
+ remote_message_queue = self.workers[0].peer_worker_response_mqs[rank]
assert remote_message_queue is not None assert remote_message_queue is not None
self.response_mqs.append(remote_message_queue) self.response_mqs.append(remote_message_queue)
@@ -128,8 +128,7 @@ class AscendMultiprocExecutor(MultiprocExecutor):
for uw in unready_workers: for uw in unready_workers:
if uw.death_writer is not None: if uw.death_writer is not None:
uw.death_writer.close() uw.death_writer.close()
- self._ensure_worker_termination(
- [uw.proc for uw in unready_workers])
+ self._ensure_worker_termination([uw.proc for uw in unready_workers])
self.futures_queue = deque[tuple[FutureWrapper, Callable]]() self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
@@ -137,7 +136,6 @@ class AscendMultiprocExecutor(MultiprocExecutor):
class AscendWorkerProc(WorkerProc): class AscendWorkerProc(WorkerProc):
@staticmethod @staticmethod
def make_worker_process( def make_worker_process(
vllm_config: VllmConfig, vllm_config: VllmConfig,

View File

@@ -3,11 +3,10 @@ import sys
import vllm.distributed.utils import vllm.distributed.utils
from vllm.platforms import CpuArchEnum, Platform from vllm.platforms import CpuArchEnum, Platform
- is_arm = (Platform.get_cpu_architecture() == CpuArchEnum.ARM)
+ is_arm = Platform.get_cpu_architecture() == CpuArchEnum.ARM
USE_SCHED_YIELD = ( USE_SCHED_YIELD = (
- ((sys.version_info[:3] >= (3, 11, 1)) or
- (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8))
- and not is_arm)
+ (sys.version_info[:3] >= (3, 11, 1)) or (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)
+ ) and not is_arm
vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD
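
> Note: to make the version gate above concrete, here is a small self-checking sketch of the same predicate written as a function. The `use_sched_yield` helper is illustrative; only the threshold logic mirrors the patch.

```python
def use_sched_yield(version_info: tuple[int, int, int], is_arm: bool) -> bool:
    # sched_yield is only used on CPython >= 3.11.1 or >= 3.10.8, and never on ARM.
    new_enough = version_info[:3] >= (3, 11, 1) or (
        version_info[:2] == (3, 10) and version_info[2] >= 8
    )
    return new_enough and not is_arm


assert use_sched_yield((3, 11, 2), is_arm=False)
assert use_sched_yield((3, 10, 9), is_arm=False)
assert not use_sched_yield((3, 10, 7), is_arm=False)
assert not use_sched_yield((3, 12, 0), is_arm=True)
```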