### What this PR does / why we need it?
| File Path |
| :--- |
| ` vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| ` vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| ` vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| ` vllm_ascend/eplb/core/eplb_utils.py` |
| ` vllm_ascend/eplb/core/eplb_worker.py` |
| ` vllm_ascend/eplb/core/policy/policy_abstract.py` |
| ` vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| ` vllm_ascend/eplb/core/policy/policy_factory.py` |
| ` vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| ` vllm_ascend/eplb/core/policy/policy_random.py` |
| ` vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| ` vllm_ascend/eplb/eplb_updator.py` |
| ` vllm_ascend/eplb/utils.py` |
| ` vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| ` vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| ` vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| ` vllm_ascend/model_loader/netloader/load.py` |
| ` vllm_ascend/model_loader/netloader/netloader.py` |
| ` vllm_ascend/model_loader/netloader/utils.py` |
| ` vllm_ascend/patch/platform/__init__.py` |
| ` vllm_ascend/patch/platform/patch_balance_schedule.py` |
| ` vllm_ascend/patch/platform/patch_ec_connector.py` |
| ` vllm_ascend/patch/platform/patch_mamba_config.py` |
| ` vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| ` vllm_ascend/patch/platform/patch_sched_yield.py` |
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -61,10 +61,6 @@ exclude = [
|
|||||||
"vllm_ascend/distributed/kv_transfer/utils/**",
|
"vllm_ascend/distributed/kv_transfer/utils/**",
|
||||||
"vllm_ascend/kv_offload/**",
|
"vllm_ascend/kv_offload/**",
|
||||||
"vllm_ascend/lora/**",
|
"vllm_ascend/lora/**",
|
||||||
# (6)
|
|
||||||
"vllm_ascend/eplb/**",
|
|
||||||
"vllm_ascend/model_loader/netloader/**",
|
|
||||||
"vllm_ascend/patch/**",
|
|
||||||
# (7)
|
# (7)
|
||||||
"vllm_ascend/quantization/**",
|
"vllm_ascend/quantization/**",
|
||||||
"vllm_ascend/sample/*.py",
|
"vllm_ascend/sample/*.py",
|
||||||
@@ -92,6 +88,7 @@ exclude = [
|
|||||||
"vllm_ascend/distributed/parallel_state.py",
|
"vllm_ascend/distributed/parallel_state.py",
|
||||||
"vllm_ascend/distributed/utils.py",
|
"vllm_ascend/distributed/utils.py",
|
||||||
"vllm_ascend/xlite/*.py",
|
"vllm_ascend/xlite/*.py",
|
||||||
|
"vllm_ascend/patch/worker/patch_*.py",
|
||||||
# (11)
|
# (11)
|
||||||
"vllm_ascend/ops/fused_moe/**",
|
"vllm_ascend/ops/fused_moe/**",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ from abc import abstractmethod
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class EplbAdaptor():
|
class EplbAdaptor:
|
||||||
|
|
||||||
def __init__(self, **args):
|
def __init__(self, **args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -29,12 +28,9 @@ class EplbAdaptor():
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_update_expert_map(self, layer_id: Any,
|
def do_update_expert_map(self, layer_id: Any, updated_expert_map: Any) -> Any:
|
||||||
updated_expert_map: Any) -> Any:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_update_expert_weight(self, layer_id: Any,
|
def do_update_expert_weight(self, layer_id: Any, local_expert_to_replace: Any, buffer_tensor_id: Any) -> Any:
|
||||||
local_expert_to_replace: Any,
|
|
||||||
buffer_tensor_id: Any) -> Any:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
|
|||||||
|
|
||||||
|
|
||||||
class VllmEplbAdaptor(EplbAdaptor):
|
class VllmEplbAdaptor(EplbAdaptor):
|
||||||
|
|
||||||
def __init__(self, model, **args):
|
def __init__(self, model, **args):
|
||||||
super().__init__(**args)
|
super().__init__(**args)
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -36,33 +35,37 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
|
self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
|
||||||
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
|
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
|
||||||
|
|
||||||
for i in range(self.num_dense_layers,
|
for i in range(self.num_dense_layers, self.model.config.num_hidden_layers):
|
||||||
self.model.config.num_hidden_layers):
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = self.model.model.layers[
|
||||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \
|
i
|
||||||
self.model.model.layers[i].mlp.experts.w13_weight_list
|
].mlp.experts.w13_weight_list
|
||||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = self.model.model.layers[
|
||||||
self.model.model.layers[i].mlp.experts.w2_weight_list
|
i
|
||||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \
|
].mlp.experts.w2_weight_list
|
||||||
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = (
|
||||||
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
|
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
|
||||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
|
)
|
||||||
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = (
|
||||||
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
|
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
|
||||||
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
)
|
||||||
|
# TODO: init self.expert_weight_names depending on different model types.
|
||||||
|
# Only deepseek v3 w8a8 and qwen3-moe is supported here
|
||||||
if self.model.quant_config is not None:
|
if self.model.quant_config is not None:
|
||||||
self.expert_weight_names = [
|
self.expert_weight_names = [
|
||||||
"w13_weight_list", "w2_weight_list",
|
"w13_weight_list",
|
||||||
"w13_weight_scale_fp32_list", "w13_weight_offset",
|
"w2_weight_list",
|
||||||
"w2_weight_scale_list", "w2_weight_offset"
|
"w13_weight_scale_fp32_list",
|
||||||
|
"w13_weight_offset",
|
||||||
|
"w2_weight_scale_list",
|
||||||
|
"w2_weight_offset",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||||
|
|
||||||
self.expert_map_per_layer_cpu = dict(
|
self.expert_map_per_layer_cpu = dict() # copy of expert map on CPU to avoid device synchronize frequently
|
||||||
) # copy of expert map on CPU to avoid device synchronize frequently
|
|
||||||
|
|
||||||
num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
|
num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
|
||||||
self.buffer_tensor_list: list[list[Any]] = [
|
self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)]
|
||||||
[] for _ in range(num_buffer_tensor)
|
|
||||||
]
|
|
||||||
self.init_buffer_tensor(num_buffer_tensor)
|
self.init_buffer_tensor(num_buffer_tensor)
|
||||||
|
|
||||||
self.expert_param_per_layer = dict()
|
self.expert_param_per_layer = dict()
|
||||||
@@ -70,18 +73,15 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
|
|
||||||
self.log2phy_map_per_layer = dict()
|
self.log2phy_map_per_layer = dict()
|
||||||
for layer_idx in range(self.num_moe_layers):
|
for layer_idx in range(self.num_moe_layers):
|
||||||
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
|
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = self.model.get_log2phy_map(
|
||||||
self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
|
self.num_dense_layers + layer_idx
|
||||||
|
)
|
||||||
|
|
||||||
def init_buffer_tensor(self, num_buffer_tensor):
|
def init_buffer_tensor(self, num_buffer_tensor):
|
||||||
for buffer_id in range(num_buffer_tensor):
|
for buffer_id in range(num_buffer_tensor):
|
||||||
for name in self.expert_weight_names:
|
for name in self.expert_weight_names:
|
||||||
complete_name = "model.layers." + str(
|
complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
|
||||||
self.num_dense_layers) + ".mlp.experts." + name
|
if name in ["w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w2_weight_scale_list"]:
|
||||||
if name in [
|
|
||||||
"w13_weight_list", "w2_weight_list",
|
|
||||||
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
|
|
||||||
]:
|
|
||||||
expert_tensor = self.param_dict[complete_name][0]
|
expert_tensor = self.param_dict[complete_name][0]
|
||||||
expert_tensor = expert_tensor.clone()
|
expert_tensor = expert_tensor.clone()
|
||||||
else:
|
else:
|
||||||
@@ -99,19 +99,20 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
per_expert_param = list()
|
per_expert_param = list()
|
||||||
for name in self.expert_weight_names:
|
for name in self.expert_weight_names:
|
||||||
if name in [
|
if name in [
|
||||||
"w13_weight_list", "w2_weight_list",
|
"w13_weight_list",
|
||||||
"w13_weight_scale_fp32_list",
|
"w2_weight_list",
|
||||||
"w2_weight_scale_list"
|
"w13_weight_scale_fp32_list",
|
||||||
|
"w2_weight_scale_list",
|
||||||
]:
|
]:
|
||||||
per_expert_param.append(
|
per_expert_param.append(
|
||||||
self.param_dict["model.layers." + str(layer_idx) +
|
self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id]
|
||||||
".mlp.experts." +
|
)
|
||||||
name][local_expert_id])
|
|
||||||
else:
|
else:
|
||||||
per_expert_param.append(
|
per_expert_param.append(
|
||||||
self.param_dict["model.layers." + str(layer_idx) +
|
self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][0].data[
|
||||||
".mlp.experts." +
|
local_expert_id
|
||||||
name][0].data[local_expert_id])
|
]
|
||||||
|
)
|
||||||
self.expert_param_per_layer[layer_idx].append(per_expert_param)
|
self.expert_param_per_layer[layer_idx].append(per_expert_param)
|
||||||
|
|
||||||
def get_rank_expert_workload(self) -> torch.Tensor:
|
def get_rank_expert_workload(self) -> torch.Tensor:
|
||||||
@@ -123,26 +124,18 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
num_local_experts = expert_maps.max() + 1
|
num_local_experts = expert_maps.max() + 1
|
||||||
|
|
||||||
expert_maps_list = expert_maps.tolist()
|
expert_maps_list = expert_maps.tolist()
|
||||||
record: dict[str, Any] = {
|
record: dict[str, Any] = {"moe_layer_count": len(expert_maps_list), "layer_list": []}
|
||||||
"moe_layer_count": len(expert_maps_list),
|
|
||||||
"layer_list": []
|
|
||||||
}
|
|
||||||
|
|
||||||
for layer_idx, layer_data in enumerate(expert_maps_list):
|
for layer_idx, layer_data in enumerate(expert_maps_list):
|
||||||
layer_record: dict[str, Any] = {
|
layer_record: dict[str, Any] = {
|
||||||
"layer_id": layer_idx,
|
"layer_id": layer_idx,
|
||||||
"device_count": len(layer_data),
|
"device_count": len(layer_data),
|
||||||
"device_list": []
|
"device_list": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for device_idx, experts in enumerate(layer_data):
|
for device_idx, experts in enumerate(layer_data):
|
||||||
placement = [
|
placement = [experts.index(i) for i in range(num_local_experts)]
|
||||||
experts.index(i) for i in range(num_local_experts)
|
device_record = {"device_id": device_idx, "device_expert": placement}
|
||||||
]
|
|
||||||
device_record = {
|
|
||||||
"device_id": device_idx,
|
|
||||||
"device_expert": placement
|
|
||||||
}
|
|
||||||
layer_record["device_list"].append(device_record)
|
layer_record["device_list"].append(device_record)
|
||||||
|
|
||||||
record["layer_list"].append(layer_record)
|
record["layer_list"].append(layer_record)
|
||||||
@@ -153,11 +146,10 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
def do_update_expert_map(self, layer_id, updated_expert_map):
|
def do_update_expert_map(self, layer_id, updated_expert_map):
|
||||||
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
|
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
|
||||||
|
|
||||||
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
|
def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id):
|
||||||
buffer_tensor_id):
|
|
||||||
for expert_tensor, buffer_tensor in zip(
|
for expert_tensor, buffer_tensor in zip(
|
||||||
self.expert_param_per_layer[layer_id][local_expert_to_replace],
|
self.expert_param_per_layer[layer_id][local_expert_to_replace], self.buffer_tensor_list[buffer_tensor_id]
|
||||||
self.buffer_tensor_list[buffer_tensor_id]):
|
):
|
||||||
expert_tensor.copy_(buffer_tensor)
|
expert_tensor.copy_(buffer_tensor)
|
||||||
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
|
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
|
||||||
|
|
||||||
@@ -168,10 +160,8 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
def get_global_expert_map(self):
|
def get_global_expert_map(self):
|
||||||
all_layer_global_expert_map = []
|
all_layer_global_expert_map = []
|
||||||
for layer_id in range(self.num_moe_layers):
|
for layer_id in range(self.num_moe_layers):
|
||||||
map_cpu = self.model.model.layers[
|
map_cpu = self.model.model.layers[self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu()
|
||||||
self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu()
|
|
||||||
all_layer_global_expert_map.append(map_cpu)
|
all_layer_global_expert_map.append(map_cpu)
|
||||||
self.expert_map_per_layer_cpu[self.num_dense_layers +
|
self.expert_map_per_layer_cpu[self.num_dense_layers + layer_id] = map_cpu[self.rank_id]
|
||||||
layer_id] = map_cpu[self.rank_id]
|
|
||||||
|
|
||||||
return torch.stack(all_layer_global_expert_map)
|
return torch.stack(all_layer_global_expert_map)
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ class ExpertWeightUpdateState(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class D2DExpertWeightLoader:
|
class D2DExpertWeightLoader:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.comm_op_list = None
|
self.comm_op_list = None
|
||||||
self.updated_expert_map = None
|
self.updated_expert_map = None
|
||||||
@@ -40,14 +39,10 @@ class D2DExpertWeightLoader:
|
|||||||
def set_adator(self, eplb_adaptor):
|
def set_adator(self, eplb_adaptor):
|
||||||
self.eplb_adaptor = eplb_adaptor
|
self.eplb_adaptor = eplb_adaptor
|
||||||
|
|
||||||
def generate_expert_d2d_transfer_task(self, expert_send_info,
|
def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, updated_expert_map, layer_id):
|
||||||
expert_recv_info, updated_expert_map,
|
|
||||||
layer_id):
|
|
||||||
# When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task
|
# When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task
|
||||||
if self.state != ExpertWeightUpdateState.WAITING:
|
if self.state != ExpertWeightUpdateState.WAITING:
|
||||||
logger.warning_once(
|
logger.warning_once("current d2d weight update tasks are on-going, cannot accept new weight update task")
|
||||||
"current d2d weight update tasks are on-going, cannot accept new weight update task"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.updated_expert_map = updated_expert_map
|
self.updated_expert_map = updated_expert_map
|
||||||
@@ -56,25 +51,16 @@ class D2DExpertWeightLoader:
|
|||||||
self.comm_op_list = []
|
self.comm_op_list = []
|
||||||
for send_info in expert_send_info:
|
for send_info in expert_send_info:
|
||||||
dst_rank, global_expert_id_to_send = send_info
|
dst_rank, global_expert_id_to_send = send_info
|
||||||
local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
|
local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
|
||||||
layer_id][global_expert_id_to_send].item()
|
for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]:
|
||||||
for src_tensor in self.eplb_adaptor.expert_param_per_layer[
|
self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
|
||||||
layer_id][local_expert_id]:
|
|
||||||
self.comm_op_list.append(
|
|
||||||
dist.P2POp(dist.isend, src_tensor, dst_rank))
|
|
||||||
|
|
||||||
buffer_tensor_id = 0
|
for buffer_tensor_id, recv_info in enumerate(expert_recv_info):
|
||||||
for recv_info in expert_recv_info:
|
|
||||||
recv_rank, global_expert_id_to_recv = recv_info
|
recv_rank, global_expert_id_to_recv = recv_info
|
||||||
for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[
|
for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]:
|
||||||
buffer_tensor_id]:
|
self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
|
||||||
self.comm_op_list.append(
|
local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item()
|
||||||
dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
|
self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id))
|
||||||
local_expert_to_replace = self.updated_expert_map[
|
|
||||||
global_expert_id_to_recv].item()
|
|
||||||
self.recv_expert_list.append(
|
|
||||||
(local_expert_to_replace, buffer_tensor_id))
|
|
||||||
buffer_tensor_id += 1
|
|
||||||
|
|
||||||
self.state = ExpertWeightUpdateState.READY
|
self.state = ExpertWeightUpdateState.READY
|
||||||
|
|
||||||
@@ -106,23 +92,18 @@ class D2DExpertWeightLoader:
|
|||||||
self.comm_op_list = None
|
self.comm_op_list = None
|
||||||
|
|
||||||
# update expert_map
|
# update expert_map
|
||||||
self.eplb_adaptor.do_update_expert_map(self.layer_id,
|
self.eplb_adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map)
|
||||||
self.updated_expert_map)
|
|
||||||
|
|
||||||
# update log2phy_map
|
# update log2phy_map
|
||||||
self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
|
self.eplb_adaptor.do_update_log2phy_map(self.layer_id, self.updated_log2phy_map)
|
||||||
self.updated_log2phy_map)
|
|
||||||
|
|
||||||
# update expert weight
|
# update expert weight
|
||||||
buffer_tensor_id = 0
|
buffer_tensor_id = 0
|
||||||
for recv_expert_info in self.recv_expert_list:
|
for recv_expert_info in self.recv_expert_list:
|
||||||
local_expert_to_replace, buffer_tensor_id = recv_expert_info
|
local_expert_to_replace, buffer_tensor_id = recv_expert_info
|
||||||
self.eplb_adaptor.do_update_expert_weight(self.layer_id,
|
self.eplb_adaptor.do_update_expert_weight(self.layer_id, local_expert_to_replace, buffer_tensor_id)
|
||||||
local_expert_to_replace,
|
|
||||||
buffer_tensor_id)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"[EPLB] finished update expert weight for layer: {self.layer_id}")
|
||||||
f"[EPLB] finished update expert weight for layer: {self.layer_id}")
|
|
||||||
|
|
||||||
self.recv_expert_list = []
|
self.recv_expert_list = []
|
||||||
self.updated_expert_map = None
|
self.updated_expert_map = None
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from vllm.logger import logger
|
|||||||
|
|
||||||
|
|
||||||
def expert_file_to_tensor(expert_map_path, layer_id):
|
def expert_file_to_tensor(expert_map_path, layer_id):
|
||||||
with open(expert_map_path, "r") as f:
|
with open(expert_map_path) as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
physical_count = 0
|
physical_count = 0
|
||||||
device_data = []
|
device_data = []
|
||||||
@@ -61,38 +61,32 @@ def init_eplb_config(eplb_config, layer_id, moe_config):
|
|||||||
eplb_enable = eplb_config.dynamic_eplb
|
eplb_enable = eplb_config.dynamic_eplb
|
||||||
n_redundant = eplb_config.num_redundant_experts if eplb_enable else 0
|
n_redundant = eplb_config.num_redundant_experts if eplb_enable else 0
|
||||||
if expert_map_path:
|
if expert_map_path:
|
||||||
if not (os.path.exists(expert_map_path)
|
if not (os.path.exists(expert_map_path) and os.access(expert_map_path, os.R_OK)):
|
||||||
and os.access(expert_map_path, os.R_OK)):
|
|
||||||
raise ValueError("Invalid EPLB path")
|
raise ValueError("Invalid EPLB path")
|
||||||
eplb_enable = True
|
eplb_enable = True
|
||||||
global_placement, physical_count = expert_file_to_tensor(
|
global_placement, physical_count = expert_file_to_tensor(expert_map_path, layer_id)
|
||||||
expert_map_path, layer_id)
|
|
||||||
if physical_count is not None:
|
if physical_count is not None:
|
||||||
n_redundant = physical_count - n_experts
|
n_redundant = physical_count - n_experts
|
||||||
if not moe_config.supports_eplb:
|
if not moe_config.supports_eplb:
|
||||||
raise ValueError(
|
raise ValueError("Eplb supports only w8a8_dynamic quantization.")
|
||||||
"Eplb supports only w8a8_dynamic quantization.")
|
|
||||||
else:
|
else:
|
||||||
eplb_enable = False
|
eplb_enable = False
|
||||||
|
|
||||||
if global_placement is None:
|
if global_placement is None:
|
||||||
global_placement = generate_global_placement(n_experts, ep_size,
|
global_placement = generate_global_placement(n_experts, ep_size, n_redundant)
|
||||||
n_redundant)
|
|
||||||
|
|
||||||
if ep_size == 1:
|
if ep_size == 1:
|
||||||
assert not eplb_enable, "EPLB must used in expert parallelism."
|
assert not eplb_enable, "EPLB must used in expert parallelism."
|
||||||
return None, None, None, n_redundant
|
return None, None, None, n_redundant
|
||||||
global_expert_map = []
|
global_expert_map = []
|
||||||
for rankid in range(ep_size):
|
for rankid in range(ep_size):
|
||||||
expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
|
expert_map = torch.full((n_experts,), -1, dtype=torch.int32)
|
||||||
local_placement = global_placement[rankid]
|
local_placement = global_placement[rankid]
|
||||||
expert_map[local_placement] = torch.arange(local_placement.shape[0],
|
expert_map[local_placement] = torch.arange(local_placement.shape[0], dtype=torch.int32)
|
||||||
dtype=torch.int32)
|
|
||||||
global_expert_map.append(expert_map)
|
global_expert_map.append(expert_map)
|
||||||
if rankid == moe_config.ep_rank:
|
if rankid == moe_config.ep_rank:
|
||||||
local_expert_map = expert_map.npu()
|
local_expert_map = expert_map.npu()
|
||||||
log2phy = generate_log2phy_map(
|
log2phy = generate_log2phy_map(global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
|
||||||
global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
|
|
||||||
|
|
||||||
return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant
|
return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant
|
||||||
|
|
||||||
@@ -106,13 +100,15 @@ def generate_log2phy_map(global_expert_map, ep_rank):
|
|||||||
if val != -1:
|
if val != -1:
|
||||||
log2phy_map[idx].append(val + rankid * valid_count)
|
log2phy_map[idx].append(val + rankid * valid_count)
|
||||||
|
|
||||||
for key in log2phy_map.keys():
|
for key in log2phy_map:
|
||||||
num_of_duplications = len(log2phy_map[key])
|
num_of_duplications = len(log2phy_map[key])
|
||||||
log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications]
|
log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications]
|
||||||
|
|
||||||
log2phy_map = torch.scatter(
|
log2phy_map = torch.scatter(
|
||||||
torch.zeros(len(log2phy_map.keys()), dtype=torch.int32), 0,
|
torch.zeros(len(log2phy_map), dtype=torch.int32),
|
||||||
torch.tensor(list(log2phy_map.keys()), dtype=torch.int64),
|
0,
|
||||||
torch.tensor(list(log2phy_map.values()), dtype=torch.int32))
|
torch.tensor(list(log2phy_map), dtype=torch.int64),
|
||||||
|
torch.tensor(list(log2phy_map.values()), dtype=torch.int32),
|
||||||
|
)
|
||||||
|
|
||||||
return log2phy_map
|
return log2phy_map
|
||||||
|
|||||||
@@ -17,23 +17,18 @@
|
|||||||
from multiprocessing import Process, Queue
|
from multiprocessing import Process, Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import networkx as nx # type: ignore
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map
|
from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map
|
||||||
from vllm_ascend.eplb.core.policy.policy_factory import (DynamicConfig,
|
from vllm_ascend.eplb.core.policy.policy_factory import DynamicConfig, PolicyFactory
|
||||||
PolicyFactory)
|
|
||||||
|
|
||||||
|
|
||||||
class EplbWorker:
|
class EplbWorker:
|
||||||
|
|
||||||
def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
|
def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
|
||||||
self.policy_type = policy_type
|
self.policy_type = policy_type
|
||||||
self.policy = PolicyFactory.generate_policy(policy_type,
|
self.policy = PolicyFactory.generate_policy(policy_type, DynamicConfig())
|
||||||
DynamicConfig())
|
|
||||||
self.shared_dict = shared_dict
|
self.shared_dict = shared_dict
|
||||||
self.old_expert_maps = None
|
self.old_expert_maps = None
|
||||||
self.enable_d2d = enable_d2d
|
self.enable_d2d = enable_d2d
|
||||||
@@ -62,10 +57,8 @@ class EplbWorker:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get the updated expert table based on the workload information
|
# Get the updated expert table based on the workload information
|
||||||
old_placement = self.global2local(self.old_expert_maps,
|
old_placement = self.global2local(self.old_expert_maps, self.num_local_experts)
|
||||||
self.num_local_experts)
|
_, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement)
|
||||||
_, _, new_placement = self.calculate_rebalance_experts(
|
|
||||||
load_info, old_placement)
|
|
||||||
|
|
||||||
if not torch.is_tensor(new_placement):
|
if not torch.is_tensor(new_placement):
|
||||||
new_placement = torch.tensor(new_placement)
|
new_placement = torch.tensor(new_placement)
|
||||||
@@ -73,8 +66,7 @@ class EplbWorker:
|
|||||||
new_expert_maps = self.local2global(new_placement)
|
new_expert_maps = self.local2global(new_placement)
|
||||||
self.update_expert_map(new_expert_maps)
|
self.update_expert_map(new_expert_maps)
|
||||||
|
|
||||||
update_info = self.compose_expert_update_info_greedy(
|
update_info = self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps)
|
||||||
new_expert_maps, self.old_expert_maps)
|
|
||||||
self.old_expert_maps = new_expert_maps
|
self.old_expert_maps = new_expert_maps
|
||||||
logger.info("EPLB Process compute complete")
|
logger.info("EPLB Process compute complete")
|
||||||
|
|
||||||
@@ -88,11 +80,8 @@ class EplbWorker:
|
|||||||
|
|
||||||
for layer_id in range(num_layers):
|
for layer_id in range(num_layers):
|
||||||
# check if any logical expert is not placed on any rank
|
# check if any logical expert is not placed on any rank
|
||||||
if torch.unique(new_placement[layer_id]).numel() < torch.unique(
|
if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel():
|
||||||
old_placement[layer_id]).numel():
|
logger.error(f"There exists expert not placed on any rank in layer {layer_id}")
|
||||||
logger.error(
|
|
||||||
f"There exists expert not placed on any rank in layer {layer_id}"
|
|
||||||
)
|
|
||||||
new_placement[layer_id] = old_placement[layer_id]
|
new_placement[layer_id] = old_placement[layer_id]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -101,28 +90,26 @@ class EplbWorker:
|
|||||||
old_placement_check = old_placement[layer_id][rank_id]
|
old_placement_check = old_placement[layer_id][rank_id]
|
||||||
|
|
||||||
# check if same logical experts are placed on the same NPU
|
# check if same logical experts are placed on the same NPU
|
||||||
if new_placement_check.numel() != torch.unique(
|
if new_placement_check.numel() != torch.unique(new_placement_check).numel():
|
||||||
new_placement_check).numel():
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
|
"Replicated experts are placed on the same NPU; expert placement on "
|
||||||
|
f"layer {layer_id}, rank {rank_id} is invalid"
|
||||||
)
|
)
|
||||||
new_placement[layer_id] = old_placement[layer_id]
|
new_placement[layer_id] = old_placement[layer_id]
|
||||||
break
|
break
|
||||||
|
|
||||||
# check if there is any experts movement inside one NPU
|
# check if there is any experts movement inside one NPU
|
||||||
expert_not_move = torch.isin(new_placement_check,
|
expert_not_move = torch.isin(new_placement_check, old_placement_check)
|
||||||
old_placement_check)
|
if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]):
|
||||||
if not torch.equal(new_placement_check[expert_not_move],
|
|
||||||
old_placement_check[expert_not_move]):
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
|
"There exists expert movement inside NPU; expert placement on "
|
||||||
|
f"layer {layer_id}, rank {rank_id} is invalid"
|
||||||
)
|
)
|
||||||
new_placement[layer_id] = old_placement[layer_id]
|
new_placement[layer_id] = old_placement[layer_id]
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
|
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
|
||||||
def compose_expert_update_info_greedy(self, updated_expert_maps,
|
def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps):
|
||||||
current_expert_maps):
|
|
||||||
num_layers = current_expert_maps.shape[0]
|
num_layers = current_expert_maps.shape[0]
|
||||||
for layer_id in range(num_layers):
|
for layer_id in range(num_layers):
|
||||||
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
||||||
@@ -132,19 +119,23 @@ class EplbWorker:
|
|||||||
expert_recv_info_this_layer: dict[Any, Any] = {}
|
expert_recv_info_this_layer: dict[Any, Any] = {}
|
||||||
|
|
||||||
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
||||||
if torch.equal(updated_expert_maps_this_layer,
|
if torch.equal(updated_expert_maps_this_layer, current_expert_maps_this_layer):
|
||||||
current_expert_maps_this_layer):
|
yield (
|
||||||
yield (expert_send_info_this_layer,
|
expert_send_info_this_layer,
|
||||||
expert_recv_info_this_layer,
|
expert_recv_info_this_layer,
|
||||||
updated_expert_maps_this_layer, layer_id)
|
updated_expert_maps_this_layer,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Parse expert_ids each rank needs to receive from other ranks
|
# Parse expert_ids each rank needs to receive from other ranks
|
||||||
dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \
|
dst_rank_indices, experts_to_recv = torch.where(
|
||||||
& (updated_expert_maps_this_layer != -1))
|
(current_expert_maps_this_layer == -1) & (updated_expert_maps_this_layer != -1)
|
||||||
|
)
|
||||||
|
|
||||||
# Parse expert_ids each rank needs to send to other ranks
|
# Parse expert_ids each rank needs to send to other ranks
|
||||||
src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \
|
src_rank_indices, experts_to_send = torch.where(
|
||||||
& (updated_expert_maps_this_layer == -1))
|
(current_expert_maps_this_layer != -1) & (updated_expert_maps_this_layer == -1)
|
||||||
|
)
|
||||||
|
|
||||||
for idx in range(len(dst_rank_indices)):
|
for idx in range(len(dst_rank_indices)):
|
||||||
dst_rank_id = dst_rank_indices[idx].item()
|
dst_rank_id = dst_rank_indices[idx].item()
|
||||||
@@ -152,27 +143,27 @@ class EplbWorker:
|
|||||||
if dst_rank_id not in expert_recv_info_this_layer:
|
if dst_rank_id not in expert_recv_info_this_layer:
|
||||||
expert_recv_info_this_layer[dst_rank_id] = []
|
expert_recv_info_this_layer[dst_rank_id] = []
|
||||||
|
|
||||||
if not torch.isin(torch.tensor(expert_id),
|
if not torch.isin(torch.tensor(expert_id), experts_to_send).any():
|
||||||
experts_to_send).any():
|
|
||||||
# if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
|
# if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
|
||||||
candidate_src_rank_indices = torch.where(
|
candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
||||||
current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
|
||||||
else:
|
else:
|
||||||
candidate_src_rank_indices = src_rank_indices[
|
candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id]
|
||||||
experts_to_send == expert_id]
|
|
||||||
|
|
||||||
# TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node...
|
# TODO: improve selection criterion of NPU sending expert_id,
|
||||||
|
# considering intra-node or inter-node...
|
||||||
src_rank_id = candidate_src_rank_indices[0].item()
|
src_rank_id = candidate_src_rank_indices[0].item()
|
||||||
if src_rank_id not in expert_send_info_this_layer:
|
if src_rank_id not in expert_send_info_this_layer:
|
||||||
expert_send_info_this_layer[src_rank_id] = []
|
expert_send_info_this_layer[src_rank_id] = []
|
||||||
|
|
||||||
expert_send_info_this_layer[src_rank_id].append(
|
expert_send_info_this_layer[src_rank_id].append((dst_rank_id, expert_id))
|
||||||
(dst_rank_id, expert_id))
|
expert_recv_info_this_layer[dst_rank_id].append((src_rank_id, expert_id))
|
||||||
expert_recv_info_this_layer[dst_rank_id].append(
|
|
||||||
(src_rank_id, expert_id))
|
|
||||||
|
|
||||||
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
|
yield (
|
||||||
updated_expert_maps_this_layer, layer_id)
|
expert_send_info_this_layer,
|
||||||
|
expert_recv_info_this_layer,
|
||||||
|
updated_expert_maps_this_layer,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
def calculate_rebalance_experts(self, load_info, old_placement):
|
def calculate_rebalance_experts(self, load_info, old_placement):
|
||||||
"""
|
"""
|
||||||
@@ -181,8 +172,7 @@ class EplbWorker:
|
|||||||
if self.old_expert_maps is None:
|
if self.old_expert_maps is None:
|
||||||
return False, None, None
|
return False, None, None
|
||||||
|
|
||||||
changed, priority, new_map = self.policy.rebalance_experts(
|
changed, priority, new_map = self.policy.rebalance_experts(old_placement, load_info)
|
||||||
old_placement, load_info)
|
|
||||||
return changed, priority, new_map
|
return changed, priority, new_map
|
||||||
|
|
||||||
def get_init_expert_maps(self):
|
def get_init_expert_maps(self):
|
||||||
@@ -199,19 +189,13 @@ class EplbWorker:
|
|||||||
return self.shared_dict.get("moe_load", None)
|
return self.shared_dict.get("moe_load", None)
|
||||||
|
|
||||||
def update_expert_map(self, expert_maps):
|
def update_expert_map(self, expert_maps):
|
||||||
|
|
||||||
self.shared_dict["expert_maps"] = expert_maps
|
self.shared_dict["expert_maps"] = expert_maps
|
||||||
|
|
||||||
def global2local(self, placement: torch.Tensor,
|
def global2local(self, placement: torch.Tensor, E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
|
|
||||||
L, G, _ = placement.shape
|
L, G, _ = placement.shape
|
||||||
device = placement.device
|
device = placement.device
|
||||||
|
|
||||||
pt_local = torch.full((L, G, E_local),
|
pt_local = torch.full((L, G, E_local), fill_value=-1, dtype=torch.long, device=device)
|
||||||
fill_value=-1,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
valid = placement >= 0
|
valid = placement >= 0
|
||||||
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
|
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
|
||||||
@@ -223,7 +207,6 @@ class EplbWorker:
|
|||||||
return pt_local
|
return pt_local
|
||||||
|
|
||||||
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
|
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
L, G, E_local = placement_local.shape
|
L, G, E_local = placement_local.shape
|
||||||
device = placement_local.device
|
device = placement_local.device
|
||||||
|
|
||||||
@@ -233,10 +216,7 @@ class EplbWorker:
|
|||||||
if E_global == 0:
|
if E_global == 0:
|
||||||
return torch.empty((L, G, 0), dtype=torch.long, device=device)
|
return torch.empty((L, G, 0), dtype=torch.long, device=device)
|
||||||
|
|
||||||
placement_global = torch.full((L, G, E_global),
|
placement_global = torch.full((L, G, E_global), fill_value=-1, dtype=torch.long, device=device)
|
||||||
fill_value=-1,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
valid = placement_local >= 0
|
valid = placement_local >= 0
|
||||||
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
|
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
|
||||||
@@ -257,11 +237,8 @@ class EplbWorker:
|
|||||||
layer_ids = []
|
layer_ids = []
|
||||||
|
|
||||||
for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
|
for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
|
||||||
|
send_info_this_rank = send_info.get(self.rank_id, [])
|
||||||
send_info_this_rank = send_info[
|
recv_info_this_rank = recv_info.get(self.rank_id, [])
|
||||||
self.rank_id] if self.rank_id in send_info else []
|
|
||||||
recv_info_this_rank = recv_info[
|
|
||||||
self.rank_id] if self.rank_id in recv_info else []
|
|
||||||
send_all.append(send_info_this_rank)
|
send_all.append(send_info_this_rank)
|
||||||
recv_all.append(recv_info_this_rank)
|
recv_all.append(recv_info_this_rank)
|
||||||
|
|
||||||
@@ -276,11 +253,7 @@ class EplbWorker:
|
|||||||
|
|
||||||
|
|
||||||
class EplbProcess:
|
class EplbProcess:
|
||||||
|
def __init__(self, shared_dict, policy_type: int = 0, enable_d2d: bool = True):
|
||||||
def __init__(self,
|
|
||||||
shared_dict,
|
|
||||||
policy_type: int = 0,
|
|
||||||
enable_d2d: bool = True):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
shared_dict: Cross-process shared dict returned by Manager().dict()
|
shared_dict: Cross-process shared dict returned by Manager().dict()
|
||||||
@@ -294,12 +267,12 @@ class EplbProcess:
|
|||||||
self.block_update_q: Queue[Any] = Queue(maxsize=1)
|
self.block_update_q: Queue[Any] = Queue(maxsize=1)
|
||||||
|
|
||||||
# Create EplbWorker instance
|
# Create EplbWorker instance
|
||||||
self.worker = EplbWorker(self.shared_dict, self.policy_type,
|
self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d)
|
||||||
self.enable_d2d)
|
|
||||||
|
|
||||||
def worker_process(self, planner_q, block_update_q):
|
def worker_process(self, planner_q, block_update_q):
|
||||||
"""
|
"""
|
||||||
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
|
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up,
|
||||||
|
call do_update, then notify main process update is complete.
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -314,17 +287,17 @@ class EplbProcess:
|
|||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[EPLB subprocess Exiting due to error: {e}",
|
logger.warning(
|
||||||
exc_info=True)
|
f"[EPLB subprocess exiting due to error: {e}]",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
def _launch_process(self):
|
def _launch_process(self):
|
||||||
"""
|
"""
|
||||||
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
|
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
|
||||||
"""
|
"""
|
||||||
proc = Process(target=self.worker_process,
|
proc = Process(target=self.worker_process, args=(self.planner_q, self.block_update_q), daemon=True)
|
||||||
args=(self.planner_q, self.block_update_q),
|
|
||||||
daemon=True)
|
|
||||||
|
|
||||||
proc.start()
|
proc.start()
|
||||||
return proc
|
return proc
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ class DynamicConfig:
|
|||||||
|
|
||||||
|
|
||||||
class EplbPolicy:
|
class EplbPolicy:
|
||||||
|
|
||||||
def __init__(self, config: DynamicConfig):
|
def __init__(self, config: DynamicConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|||||||
@@ -25,13 +25,11 @@ class DynamicTable:
|
|||||||
|
|
||||||
|
|
||||||
class DefaultEplb(EplbPolicy):
|
class DefaultEplb(EplbPolicy):
|
||||||
|
|
||||||
def __init__(self, config: DynamicConfig):
|
def __init__(self, config: DynamicConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_redundant(current_expert_table, expert_workload,
|
def add_redundant(current_expert_table, expert_workload, num_original_expert):
|
||||||
num_original_expert):
|
|
||||||
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
||||||
workload_new = np.zeros((layer_num, num_original_expert))
|
workload_new = np.zeros((layer_num, num_original_expert))
|
||||||
for layer_idx in range(layer_num):
|
for layer_idx in range(layer_num):
|
||||||
@@ -40,31 +38,24 @@ class DefaultEplb(EplbPolicy):
|
|||||||
workload_layer = expert_workload[layer_idx].copy()
|
workload_layer = expert_workload[layer_idx].copy()
|
||||||
for npu_idx in range(npu_num):
|
for npu_idx in range(npu_num):
|
||||||
for expert_idx in range(experts_per_npu):
|
for expert_idx in range(experts_per_npu):
|
||||||
workload_dict[placement_layer[npu_idx][
|
workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx]
|
||||||
expert_idx]] += workload_layer[npu_idx][expert_idx]
|
|
||||||
for expert_idx in range(num_original_expert):
|
for expert_idx in range(num_original_expert):
|
||||||
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
||||||
return workload_new
|
return workload_new
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Split hot (high-load) experts into redundant experts
|
# Split hot (high-load) experts into redundant experts
|
||||||
def original_compute_balanced_pack_redundancy(origin_weights, card_num,
|
def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert):
|
||||||
num_redundancy_expert):
|
|
||||||
# Step 1: Sort the items by weight in descending order (we are sorting by weight now)
|
# Step 1: Sort the items by weight in descending order (we are sorting by weight now)
|
||||||
# Sort based on the second element (the second value of each tuple)
|
# Sort based on the second element (the second value of each tuple)
|
||||||
route_expert_num = len(origin_weights)
|
route_expert_num = len(origin_weights)
|
||||||
route_expert_redundancy: list[list[int]] = [
|
route_expert_redundancy: list[list[int]] = [[] for _ in range(route_expert_num)]
|
||||||
[] for _ in range(route_expert_num)
|
|
||||||
]
|
|
||||||
for i in range(num_redundancy_expert):
|
for i in range(num_redundancy_expert):
|
||||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
weights = [origin_weights[idx] for idx in sorted_indices]
|
weights = [origin_weights[idx] for idx in sorted_indices]
|
||||||
tmp_raw_weight = weights[0][1] * (
|
tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
|
||||||
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
||||||
avg_weight = tmp_raw_weight / (
|
avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
|
||||||
weights[0] = (weights[0][0], avg_weight)
|
weights[0] = (weights[0][0], avg_weight)
|
||||||
origin_weights = weights
|
origin_weights = weights
|
||||||
|
|
||||||
@@ -93,8 +84,7 @@ class DefaultEplb(EplbPolicy):
|
|||||||
box_counts[index] += 1
|
box_counts[index] += 1
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
origin_weights = [origin_weights[idx] for idx in sorted_indices]
|
origin_weights = [origin_weights[idx] for idx in sorted_indices]
|
||||||
# Step 4: Distribute items into boxes based on weight
|
# Step 4: Distribute items into boxes based on weight
|
||||||
for item_id, weight in origin_weights:
|
for item_id, weight in origin_weights:
|
||||||
@@ -104,11 +94,8 @@ class DefaultEplb(EplbPolicy):
|
|||||||
if item_id in boxes[i]:
|
if item_id in boxes[i]:
|
||||||
continue
|
continue
|
||||||
# Only choose boxes that still have space (box_counts[i] < items_per_box)
|
# Only choose boxes that still have space (box_counts[i] < items_per_box)
|
||||||
if box_counts[i] < items_per_box or (box_counts[i]
|
if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0):
|
||||||
== items_per_box
|
if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]:
|
||||||
and remaining_items > 0):
|
|
||||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
|
||||||
min_box_index]:
|
|
||||||
min_box_index = i
|
min_box_index = i
|
||||||
|
|
||||||
# Place the item (id) into the selected box
|
# Place the item (id) into the selected box
|
||||||
@@ -118,40 +105,35 @@ class DefaultEplb(EplbPolicy):
|
|||||||
box_counts[min_box_index] += 1
|
box_counts[min_box_index] += 1
|
||||||
|
|
||||||
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter
|
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter
|
||||||
if box_counts[min_box_index] == (items_per_box +
|
if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0:
|
||||||
1) and remaining_items > 0:
|
|
||||||
remaining_items -= 1
|
remaining_items -= 1
|
||||||
|
|
||||||
# Step 5: Output each box's contents and total weight
|
# Step 5: Output each box's contents and total weight
|
||||||
result = []
|
result = []
|
||||||
for i in range(card_num):
|
for i in range(card_num):
|
||||||
result.append({
|
result.append(
|
||||||
"box_index": i + 1,
|
{
|
||||||
"items": boxes[i], # List of item IDs in the box
|
"box_index": i + 1,
|
||||||
"weight": boxes_weights[i],
|
"items": boxes[i], # List of item IDs in the box
|
||||||
"total_weight": box_weights[i], # Total weight in this box
|
"weight": boxes_weights[i],
|
||||||
"item_count": box_counts[i] # Number of items in the box
|
"total_weight": box_weights[i], # Total weight in this box
|
||||||
})
|
"item_count": box_counts[i], # Number of items in the box
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result, boxes
|
return result, boxes
|
||||||
|
|
||||||
# Split hot (high-load) experts into redundant experts
|
# Split hot (high-load) experts into redundant experts
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_balanced_pack_redundancy(origin_weights, card_num,
|
def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert):
|
||||||
num_redundancy_expert):
|
|
||||||
route_expert_num = len(origin_weights)
|
route_expert_num = len(origin_weights)
|
||||||
route_expert_redundancy: list[list[int]] = [
|
route_expert_redundancy: list[list[int]] = [[] for _ in range(route_expert_num)]
|
||||||
[] for _ in range(route_expert_num)
|
|
||||||
]
|
|
||||||
for i in range(num_redundancy_expert):
|
for i in range(num_redundancy_expert):
|
||||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
weights = [origin_weights[idx] for idx in sorted_indices]
|
weights = [origin_weights[idx] for idx in sorted_indices]
|
||||||
tmp_raw_weight = weights[0][1] * (
|
tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
|
||||||
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
||||||
avg_weight = tmp_raw_weight / (
|
avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
|
||||||
weights[0] = (weights[0][0], avg_weight)
|
weights[0] = (weights[0][0], avg_weight)
|
||||||
origin_weights = weights
|
origin_weights = weights
|
||||||
|
|
||||||
@@ -166,7 +148,7 @@ class DefaultEplb(EplbPolicy):
|
|||||||
box_weights = [0] * card_num
|
box_weights = [0] * card_num
|
||||||
box_counts = [0] * card_num
|
box_counts = [0] * card_num
|
||||||
|
|
||||||
all_weights = np.zeros((expert_num, ), dtype='object')
|
all_weights = np.zeros((expert_num,), dtype="object")
|
||||||
all_weights[:route_expert_num] = origin_weights
|
all_weights[:route_expert_num] = origin_weights
|
||||||
|
|
||||||
index = route_expert_num
|
index = route_expert_num
|
||||||
@@ -178,17 +160,13 @@ class DefaultEplb(EplbPolicy):
|
|||||||
all_weights[index] = (item, weight)
|
all_weights[index] = (item, weight)
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
sorted_indices = np.argsort([t[1] for t in all_weights],
|
sorted_indices = np.argsort([t[1] for t in all_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
all_weights = [all_weights[idx] for idx in sorted_indices]
|
all_weights = [all_weights[idx] for idx in sorted_indices]
|
||||||
for item_id, weight in all_weights:
|
for item_id, weight in all_weights:
|
||||||
min_box_index = -1
|
min_box_index = -1
|
||||||
for i in range(card_num):
|
for i in range(card_num):
|
||||||
if box_counts[i] < items_per_box or (box_counts[i]
|
if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0):
|
||||||
== items_per_box
|
if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]:
|
||||||
and remaining_items > 0):
|
|
||||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
|
||||||
min_box_index]:
|
|
||||||
if item_id not in boxes[i]:
|
if item_id not in boxes[i]:
|
||||||
min_box_index = i
|
min_box_index = i
|
||||||
|
|
||||||
@@ -197,19 +175,20 @@ class DefaultEplb(EplbPolicy):
|
|||||||
box_weights[min_box_index] += weight
|
box_weights[min_box_index] += weight
|
||||||
box_counts[min_box_index] += 1
|
box_counts[min_box_index] += 1
|
||||||
|
|
||||||
if box_counts[min_box_index] == (items_per_box +
|
if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0:
|
||||||
1) and remaining_items > 0:
|
|
||||||
remaining_items -= 1
|
remaining_items -= 1
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for i in range(card_num):
|
for i in range(card_num):
|
||||||
result.append({
|
result.append(
|
||||||
"box_index": i + 1,
|
{
|
||||||
"items": boxes[i],
|
"box_index": i + 1,
|
||||||
"weight": boxes_weights[i],
|
"items": boxes[i],
|
||||||
"total_weight": box_weights[i],
|
"weight": boxes_weights[i],
|
||||||
"item_count": box_counts[i]
|
"total_weight": box_weights[i],
|
||||||
})
|
"item_count": box_counts[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result, boxes
|
return result, boxes
|
||||||
|
|
||||||
@@ -232,11 +211,8 @@ class DefaultEplb(EplbPolicy):
|
|||||||
for item_id, weight in weights:
|
for item_id, weight in weights:
|
||||||
min_box_index = -1
|
min_box_index = -1
|
||||||
for i in range(card_num):
|
for i in range(card_num):
|
||||||
if box_counts[i] < items_per_box or (box_counts[i]
|
if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0):
|
||||||
== items_per_box
|
if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]:
|
||||||
and remaining_items > 0):
|
|
||||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
|
||||||
min_box_index]:
|
|
||||||
min_box_index = i
|
min_box_index = i
|
||||||
|
|
||||||
boxes[min_box_index].append(item_id)
|
boxes[min_box_index].append(item_id)
|
||||||
@@ -244,19 +220,20 @@ class DefaultEplb(EplbPolicy):
|
|||||||
box_weights[min_box_index] += weight
|
box_weights[min_box_index] += weight
|
||||||
box_counts[min_box_index] += 1
|
box_counts[min_box_index] += 1
|
||||||
|
|
||||||
if box_counts[min_box_index] == (items_per_box +
|
if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0:
|
||||||
1) and remaining_items > 0:
|
|
||||||
remaining_items -= 1
|
remaining_items -= 1
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for i in range(card_num):
|
for i in range(card_num):
|
||||||
result.append({
|
result.append(
|
||||||
"box_index": i + 1,
|
{
|
||||||
"items": boxes[i],
|
"box_index": i + 1,
|
||||||
"weight": boxes_weights[i],
|
"items": boxes[i],
|
||||||
"total_weight": box_weights[i],
|
"weight": boxes_weights[i],
|
||||||
"item_count": box_counts[i]
|
"total_weight": box_weights[i],
|
||||||
})
|
"item_count": box_counts[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result, boxes
|
return result, boxes
|
||||||
|
|
||||||
@@ -274,16 +251,11 @@ class DefaultEplb(EplbPolicy):
|
|||||||
return max_heat_per_layer
|
return max_heat_per_layer
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def constraint_expert_local_exchange(current_expert_table,
|
def constraint_expert_local_exchange(current_expert_table, global_deployment):
|
||||||
global_deployment):
|
|
||||||
for layer_id in range(len(global_deployment)):
|
for layer_id in range(len(global_deployment)):
|
||||||
for card_id in range(len(global_deployment[layer_id])):
|
for card_id in range(len(global_deployment[layer_id])):
|
||||||
current_list = [
|
current_list = [int(x) for x in current_expert_table[layer_id][card_id]]
|
||||||
int(x) for x in current_expert_table[layer_id][card_id]
|
new_list = [int(x) for x in global_deployment[layer_id][card_id]]
|
||||||
]
|
|
||||||
new_list = [
|
|
||||||
int(x) for x in global_deployment[layer_id][card_id]
|
|
||||||
]
|
|
||||||
num = len(new_list)
|
num = len(new_list)
|
||||||
|
|
||||||
new_index = [-1] * num
|
new_index = [-1] * num
|
||||||
@@ -293,8 +265,7 @@ class DefaultEplb(EplbPolicy):
|
|||||||
for i in range(num):
|
for i in range(num):
|
||||||
flag = True
|
flag = True
|
||||||
for j in range(num):
|
for j in range(num):
|
||||||
if new_list[i] == current_list[j] and new_index[
|
if new_list[i] == current_list[j] and new_index[j] == -1:
|
||||||
j] == -1:
|
|
||||||
new_index[j] = 0
|
new_index[j] = 0
|
||||||
new_result[j] = current_list[j]
|
new_result[j] = current_list[j]
|
||||||
flag = False
|
flag = False
|
||||||
@@ -313,7 +284,6 @@ class DefaultEplb(EplbPolicy):
|
|||||||
return global_deployment
|
return global_deployment
|
||||||
|
|
||||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||||
|
|
||||||
info = DynamicTable()
|
info = DynamicTable()
|
||||||
info.workload_table = np.array(expert_workload)
|
info.workload_table = np.array(expert_workload)
|
||||||
info.placement_table = np.array(current_expert_table)
|
info.placement_table = np.array(current_expert_table)
|
||||||
@@ -324,17 +294,15 @@ class DefaultEplb(EplbPolicy):
|
|||||||
expert_ids, counts = np.unique(row, return_counts=True)
|
expert_ids, counts = np.unique(row, return_counts=True)
|
||||||
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
||||||
num_original_expert = len(expert_ids)
|
num_original_expert = len(expert_ids)
|
||||||
layer_workloads = self.add_redundant(info.placement_table,
|
layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert)
|
||||||
info.workload_table,
|
max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num)
|
||||||
num_original_expert)
|
|
||||||
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
|
|
||||||
info.workload_table, layer_num)
|
|
||||||
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
||||||
|
|
||||||
# Perform load balancing and deploy redundant experts
|
# Perform load balancing and deploy redundant experts
|
||||||
layer_num = layer_workloads.shape[0]
|
layer_num = layer_workloads.shape[0]
|
||||||
expert_num = layer_workloads.shape[1]
|
expert_num = layer_workloads.shape[1]
|
||||||
# Validate that the number of experts, number of cards, and number of redundant experts do not exceed the number of cards
|
# Validate that the number of experts, number of cards, and number of redundant experts
|
||||||
|
# do not exceed the number of cards.
|
||||||
if num_original_expert != expert_num:
|
if num_original_expert != expert_num:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
|
f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
|
||||||
@@ -345,38 +313,35 @@ class DefaultEplb(EplbPolicy):
|
|||||||
|
|
||||||
if num_npus < num_redundancy_expert:
|
if num_npus < num_redundancy_expert:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}"
|
"the number of NPUs "
|
||||||
|
f"{num_npus} must be greater than or equal to the number of redundant experts "
|
||||||
|
f"{num_redundancy_expert}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Number of experts deployed on each card includes one redundant expert
|
# Number of experts deployed on each card includes one redundant expert
|
||||||
global_deployment: list[list[list[int]]] = [[[]
|
global_deployment: list[list[list[int]]] = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
|
||||||
for _ in range(num_npus)]
|
|
||||||
for _ in range(layer_num)]
|
|
||||||
# Iterate to obtain the placement strategy for each layer, taking computational balance into account
|
# Iterate to obtain the placement strategy for each layer, taking computational balance into account
|
||||||
max_heat_per_layer_after = np.zeros([layer_num])
|
max_heat_per_layer_after = np.zeros([layer_num])
|
||||||
for layer in range(layer_num):
|
for layer in range(layer_num):
|
||||||
# Get the expert IDs and their corresponding workloads for the current layer;
|
# Get the expert IDs and their corresponding workloads for the current layer;
|
||||||
# workloads need to be normalized, and one redundant expert is added per card
|
# workloads need to be normalized, and one redundant expert is added per card
|
||||||
weights = np.zeros((expert_num, ), dtype='object')
|
weights = np.zeros((expert_num,), dtype="object")
|
||||||
for expert_id, workload_weight in enumerate(
|
for expert_id, workload_weight in enumerate(layer_workloads[layer]):
|
||||||
layer_workloads[layer]):
|
|
||||||
weights[expert_id] = (expert_id, workload_weight)
|
weights[expert_id] = (expert_id, workload_weight)
|
||||||
|
|
||||||
# Obtain the globally balanced placement strategy for each layer
|
# Obtain the globally balanced placement strategy for each layer
|
||||||
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
|
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
|
||||||
weights, num_npus, num_redundancy_expert)
|
weights, num_npus, num_redundancy_expert
|
||||||
|
)
|
||||||
|
|
||||||
global_deployment[layer] = layer_deployment
|
global_deployment[layer] = layer_deployment
|
||||||
max_heat_per_layer_after[layer] = max(
|
max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_weight"])["total_weight"]
|
||||||
result, key=lambda x: x['total_weight'])['total_weight']
|
|
||||||
|
|
||||||
new_global_deployment = self.constraint_expert_local_exchange(
|
new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
|
||||||
current_expert_table, global_deployment)
|
|
||||||
# Obtain the priority of each layer
|
# Obtain the priority of each layer
|
||||||
layer_changed_ratio = []
|
layer_changed_ratio = []
|
||||||
for layer_idx in range(layer_num):
|
for layer_idx in range(layer_num):
|
||||||
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] /
|
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
|
||||||
max_heat_per_layer_before[layer_idx])
|
|
||||||
|
|
||||||
per_layer_priority = np.argsort(layer_changed_ratio)
|
per_layer_priority = np.argsort(layer_changed_ratio)
|
||||||
npu_heat_all_after = sum(max_heat_per_layer_after)
|
npu_heat_all_after = sum(max_heat_per_layer_after)
|
||||||
@@ -385,5 +350,4 @@ class DefaultEplb(EplbPolicy):
|
|||||||
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
||||||
change = 1
|
change = 1
|
||||||
|
|
||||||
return change, per_layer_priority, np.array(
|
return change, per_layer_priority, np.array(new_global_deployment).tolist()
|
||||||
new_global_deployment).tolist()
|
|
||||||
|
|||||||
@@ -2,29 +2,26 @@
|
|||||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
|
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
|
||||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||||
from .policy_default_eplb import DefaultEplb
|
from .policy_default_eplb import DefaultEplb
|
||||||
from .policy_swift_balancer import SwiftBalanceEplb
|
|
||||||
from .policy_flashlb import FlashLB, warm_up
|
from .policy_flashlb import FlashLB, warm_up
|
||||||
from .policy_random import RandomLoadBalance
|
from .policy_random import RandomLoadBalance
|
||||||
|
from .policy_swift_balancer import SwiftBalanceEplb
|
||||||
|
|
||||||
|
|
||||||
class PolicyFactory:
|
class PolicyFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
|
def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
|
||||||
policy = {
|
policy = {
|
||||||
# Constraint applying Dynamic EPLB policy V2:
|
# Constraint applying Dynamic EPLB policy V2:
|
||||||
# If there exists redundant expert:
|
# If there exists redundant expert:
|
||||||
# only one redundant expert can be placed in one NPU and its physical expert index must be 0
|
# only one redundant expert can be placed in one NPU and its physical expert index must be 0
|
||||||
|
|
||||||
# Applying greedy d2d expert weight update composing
|
# Applying greedy d2d expert weight update composing
|
||||||
0:
|
0: RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
|
||||||
RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
|
1: DefaultEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
|
||||||
1:
|
# Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
|
||||||
DefaultEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
|
2: SwiftBalanceEplb,
|
||||||
2:
|
# FlashLB EPLB policy: expert replacement based on Joint Optimization,
|
||||||
SwiftBalanceEplb, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
|
# Multi-Shot Enhancement and Incremental Adjustment
|
||||||
3:
|
3: FlashLB,
|
||||||
FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
|
|
||||||
}
|
}
|
||||||
policy_class = policy.get(policy_type, RandomLoadBalance)
|
policy_class = policy.get(policy_type, RandomLoadBalance)
|
||||||
policy_instance = policy_class(config)
|
policy_instance = policy_class(config)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -45,8 +44,7 @@ def compute_piece_counts(X, P, stage_weights):
|
|||||||
secv = unit[i, idx2]
|
secv = unit[i, idx2]
|
||||||
alt = X[i, idx1] / (pieces[idx1] + 1)
|
alt = X[i, idx1] / (pieces[idx1] + 1)
|
||||||
delta = origin - (alt if alt > secv else secv)
|
delta = origin - (alt if alt > secv else secv)
|
||||||
deltas[idx1] += delta * stage_weights[i] if np.any(
|
deltas[idx1] += delta * stage_weights[i] if np.any(delta) != 0 else stage_weights[i]
|
||||||
delta) != 0 else stage_weights[i]
|
|
||||||
|
|
||||||
max_idx = np.argmax(deltas)
|
max_idx = np.argmax(deltas)
|
||||||
pieces[max_idx] += 1
|
pieces[max_idx] += 1
|
||||||
@@ -157,9 +155,7 @@ def jsq_placement(X, pieces, M, stage_weights):
|
|||||||
# Get elements already in current column
|
# Get elements already in current column
|
||||||
current_rank_elements = set(deployment[rank, :])
|
current_rank_elements = set(deployment[rank, :])
|
||||||
# Filter elements from range(N) not in current column
|
# Filter elements from range(N) not in current column
|
||||||
available = [
|
available = [x for x in range(N) if x not in current_rank_elements]
|
||||||
x for x in range(N) if x not in current_rank_elements
|
|
||||||
]
|
|
||||||
# Randomly select an available element to fill
|
# Randomly select an available element to fill
|
||||||
if len(available) > 0:
|
if len(available) > 0:
|
||||||
rand_idx = np.random.randint(0, len(available))
|
rand_idx = np.random.randint(0, len(available))
|
||||||
@@ -187,8 +183,7 @@ def slice_values(X, pieces):
|
|||||||
|
|
||||||
|
|
||||||
@njit
|
@njit
|
||||||
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, simulated_deployment, stage_weights):
|
||||||
simulated_deployment, stage_weights):
|
|
||||||
n_stage, N = X.shape
|
n_stage, N = X.shape
|
||||||
num_group = P // M
|
num_group = P // M
|
||||||
|
|
||||||
@@ -207,12 +202,11 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
|||||||
flat_deployment = simulated_deployment.reshape(-1)
|
flat_deployment = simulated_deployment.reshape(-1)
|
||||||
simulated_load = np.zeros(M, dtype=np.float32)
|
simulated_load = np.zeros(M, dtype=np.float32)
|
||||||
for i in range(flat_deployment.shape[0]):
|
for i in range(flat_deployment.shape[0]):
|
||||||
simulated_load[i // (flat_deployment.shape[0] //
|
simulated_load[i // (flat_deployment.shape[0] // M)] += unit_load[flat_deployment[i]]
|
||||||
M)] += unit_load[flat_deployment[i]]
|
|
||||||
|
|
||||||
slice_vals = slice_values(X_all, simulated_pieces)
|
slice_vals = slice_values(X_all, simulated_pieces)
|
||||||
sorted_slices = np.sort(slice_vals)[::-1]
|
sorted_slices = np.sort(slice_vals)[::-1]
|
||||||
simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M
|
simulated_slopes = (sorted_slices[: -M + 1] - sorted_slices[M - 1 :]) / M
|
||||||
|
|
||||||
cumulative_slices_used = np.zeros(N, dtype=np.int32)
|
cumulative_slices_used = np.zeros(N, dtype=np.int32)
|
||||||
acc = 0
|
acc = 0
|
||||||
@@ -230,8 +224,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
|||||||
slices_used_per_group = np.zeros(num_group, dtype=np.int32)
|
slices_used_per_group = np.zeros(num_group, dtype=np.int32)
|
||||||
slices_used_per_group[0] = group_boundary_indices[0]
|
slices_used_per_group[0] = group_boundary_indices[0]
|
||||||
for i in range(1, num_group):
|
for i in range(1, num_group):
|
||||||
slices_used_per_group[
|
slices_used_per_group[i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
|
||||||
i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
|
|
||||||
slices_used_per_group = M - slices_used_per_group
|
slices_used_per_group = M - slices_used_per_group
|
||||||
|
|
||||||
loads = np.zeros(M, dtype=np.float32)
|
loads = np.zeros(M, dtype=np.float32)
|
||||||
@@ -240,7 +233,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
|||||||
current_idx = 0
|
current_idx = 0
|
||||||
|
|
||||||
for g in range(num_group):
|
for g in range(num_group):
|
||||||
window = X_sorted[:, current_idx:current_idx + 2 * M]
|
window = X_sorted[:, current_idx : current_idx + 2 * M]
|
||||||
low = max(0, current_idx + M - N)
|
low = max(0, current_idx + M - N)
|
||||||
high = min(num_remain_slice, M - 1)
|
high = min(num_remain_slice, M - 1)
|
||||||
|
|
||||||
@@ -248,8 +241,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
|||||||
mid = int((high + low) // 2)
|
mid = int((high + low) // 2)
|
||||||
keep = M - mid
|
keep = M - mid
|
||||||
current_group = window[:, :keep]
|
current_group = window[:, :keep]
|
||||||
current_pieces = compute_piece_counts(current_group, M,
|
current_pieces = compute_piece_counts(current_group, M, stage_weights)
|
||||||
stage_weights)
|
|
||||||
current_pieces = np.maximum(current_pieces, 1)
|
current_pieces = np.maximum(current_pieces, 1)
|
||||||
current_slice = slice_values(current_group.sum(0), current_pieces)
|
current_slice = slice_values(current_group.sum(0), current_pieces)
|
||||||
current_slice_sorted = np.sort(current_slice)
|
current_slice_sorted = np.sort(current_slice)
|
||||||
@@ -257,8 +249,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
|||||||
current_max: np.float32 = np.max(current_loads)
|
current_max: np.float32 = np.max(current_loads)
|
||||||
current_min: np.float32 = np.min(current_loads)
|
current_min: np.float32 = np.min(current_loads)
|
||||||
current_slope = (current_max - current_min) / M
|
current_slope = (current_max - current_min) / M
|
||||||
next_slope: np.float32 = np.max(simulated_slopes[current_idx +
|
next_slope: np.float32 = np.max(simulated_slopes[current_idx + keep :])
|
||||||
keep:])
|
|
||||||
|
|
||||||
if abs(current_slope) > abs(next_slope):
|
if abs(current_slope) > abs(next_slope):
|
||||||
low = mid
|
low = mid
|
||||||
@@ -327,9 +318,7 @@ def auto_fix_new_placement(old_placement, new_placement):
|
|||||||
old_row = old_placement[rank_id]
|
old_row = old_placement[rank_id]
|
||||||
new_row = new_placement[rank_id]
|
new_row = new_placement[rank_id]
|
||||||
|
|
||||||
index_array = np.full((max_expert + 1, num_experts),
|
index_array = np.full((max_expert + 1, num_experts), -1, dtype=np.int32)
|
||||||
-1,
|
|
||||||
dtype=np.int32)
|
|
||||||
count_array = np.zeros(max_expert + 1, dtype=np.int32)
|
count_array = np.zeros(max_expert + 1, dtype=np.int32)
|
||||||
|
|
||||||
for idx in range(num_experts):
|
for idx in range(num_experts):
|
||||||
@@ -387,21 +376,17 @@ def auto_fix_new_placement(old_placement, new_placement):
|
|||||||
|
|
||||||
|
|
||||||
class FlashLB(EplbPolicy):
|
class FlashLB(EplbPolicy):
|
||||||
|
|
||||||
def __init__(self, config: DynamicConfig):
|
def __init__(self, config: DynamicConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.par_history: Dict[int, float] = {}
|
self.par_history: dict[int, float] = {}
|
||||||
self.hotness_window: Dict[int, deque[float]] = {}
|
self.hotness_window: dict[int, deque[float]] = {}
|
||||||
self.max_stage_window = (config.max_stage_window if hasattr(
|
self.max_stage_window = config.max_stage_window if hasattr(config, "max_stage_window") else 1
|
||||||
config, "max_stage_window") else 1)
|
|
||||||
self.buffer_expert_layer_num = (
|
self.buffer_expert_layer_num = (
|
||||||
config.buffer_expert_layer_num if hasattr(
|
config.buffer_expert_layer_num if hasattr(config, "buffer_expert_layer_num") else 58
|
||||||
config, "buffer_expert_layer_num") else 58)
|
)
|
||||||
self.threshold_ratio = (config.threshold_ratio if hasattr(
|
self.threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else 0
|
||||||
config, "threshold_ratio") else 0)
|
|
||||||
|
|
||||||
def compute_expert_hotness(self, num_of_expert: int,
|
def compute_expert_hotness(self, num_of_expert: int, deployment: np.ndarray, rank_load: np.ndarray):
|
||||||
deployment: np.ndarray, rank_load: np.ndarray):
|
|
||||||
hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
|
hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
|
||||||
deployment_flat = deployment.ravel()
|
deployment_flat = deployment.ravel()
|
||||||
rank_load_flat = rank_load.ravel()
|
rank_load_flat = rank_load.ravel()
|
||||||
@@ -413,22 +398,14 @@ class FlashLB(EplbPolicy):
|
|||||||
if np.any(deployment < 0):
|
if np.any(deployment < 0):
|
||||||
raise ValueError("Deployment table contains negative values.")
|
raise ValueError("Deployment table contains negative values.")
|
||||||
counts = np.bincount(deployment.reshape(-1), minlength=N)
|
counts = np.bincount(deployment.reshape(-1), minlength=N)
|
||||||
unit_hotness = np.divide(hotness,
|
unit_hotness = np.divide(hotness, counts, out=np.zeros_like(hotness, dtype=float), where=counts != 0)
|
||||||
counts,
|
|
||||||
out=np.zeros_like(hotness, dtype=float),
|
|
||||||
where=counts != 0)
|
|
||||||
stage_par = np.zeros(n_stage)
|
stage_par = np.zeros(n_stage)
|
||||||
for i in range(n_stage):
|
for i in range(n_stage):
|
||||||
stage_load = unit_hotness[i][deployment].sum(-1)
|
stage_load = unit_hotness[i][deployment].sum(-1)
|
||||||
stage_par[i] = stage_load.max() / stage_load.mean()
|
stage_par[i] = stage_load.max() / stage_load.mean()
|
||||||
return stage_par.mean()
|
return stage_par.mean()
|
||||||
|
|
||||||
def group_based_adaptive_bloating(self,
|
def group_based_adaptive_bloating(self, X, P, M, stage_weights=None, recorsive=False):
|
||||||
X,
|
|
||||||
P,
|
|
||||||
M,
|
|
||||||
stage_weights=None,
|
|
||||||
recorsive=False):
|
|
||||||
n_stage, N = X.shape
|
n_stage, N = X.shape
|
||||||
if stage_weights is None:
|
if stage_weights is None:
|
||||||
stage_weights = np.ones(n_stage, dtype=np.float32)
|
stage_weights = np.ones(n_stage, dtype=np.float32)
|
||||||
@@ -437,15 +414,10 @@ class FlashLB(EplbPolicy):
|
|||||||
(
|
(
|
||||||
simulated_deployment,
|
simulated_deployment,
|
||||||
simulated_pieces,
|
simulated_pieces,
|
||||||
) = self.group_based_adaptive_bloating(X,
|
) = self.group_based_adaptive_bloating(X, P, M, stage_weights, recorsive=False)
|
||||||
P,
|
|
||||||
M,
|
|
||||||
stage_weights,
|
|
||||||
recorsive=False)
|
|
||||||
else:
|
else:
|
||||||
simulated_pieces = compute_piece_counts(X, P, stage_weights)
|
simulated_pieces = compute_piece_counts(X, P, stage_weights)
|
||||||
simulated_deployment = jsq_placement(X, simulated_pieces, M,
|
simulated_deployment = jsq_placement(X, simulated_pieces, M, stage_weights)
|
||||||
stage_weights)
|
|
||||||
|
|
||||||
pieces = group_based_adaptive_bloating_kernel(
|
pieces = group_based_adaptive_bloating_kernel(
|
||||||
X.astype(np.float32),
|
X.astype(np.float32),
|
||||||
@@ -459,10 +431,7 @@ class FlashLB(EplbPolicy):
|
|||||||
deployment = jsq_placement(X, pieces, M, stage_weights)
|
deployment = jsq_placement(X, pieces, M, stage_weights)
|
||||||
|
|
||||||
X_all = X.sum(0)
|
X_all = X.sum(0)
|
||||||
unit_load = np.divide(X_all,
|
unit_load = np.divide(X_all, pieces, out=np.zeros_like(X_all, dtype=float), where=pieces != 0)
|
||||||
pieces,
|
|
||||||
out=np.zeros_like(X_all, dtype=float),
|
|
||||||
where=pieces != 0)
|
|
||||||
load = unit_load[deployment].sum(-1)
|
load = unit_load[deployment].sum(-1)
|
||||||
|
|
||||||
sim_unit_load = X_all / simulated_pieces
|
sim_unit_load = X_all / simulated_pieces
|
||||||
@@ -510,16 +479,13 @@ class FlashLB(EplbPolicy):
|
|||||||
def register_hotness(self, deployment, rank_load, num_layer, num_expert):
|
def register_hotness(self, deployment, rank_load, num_layer, num_expert):
|
||||||
for layer in range(num_layer):
|
for layer in range(num_layer):
|
||||||
if layer not in self.hotness_window:
|
if layer not in self.hotness_window:
|
||||||
self.hotness_window[layer] = deque(
|
self.hotness_window[layer] = deque(maxlen=self.max_stage_window)
|
||||||
maxlen=self.max_stage_window)
|
hotness = self.compute_expert_hotness(num_expert, deployment[layer], rank_load[layer])
|
||||||
hotness = self.compute_expert_hotness(num_expert,
|
|
||||||
deployment[layer],
|
|
||||||
rank_load[layer])
|
|
||||||
self.hotness_window[layer].append(hotness)
|
self.hotness_window[layer].append(hotness)
|
||||||
|
|
||||||
def compress_by_avg_pooling_fast_nd(self, arr, m):
|
def compress_by_avg_pooling_fast_nd(self, arr, m):
|
||||||
n, d = arr.shape
|
n, d = arr.shape
|
||||||
idx = (np.arange(n) * m // n)
|
idx = np.arange(n) * m // n
|
||||||
result = np.zeros((m, d))
|
result = np.zeros((m, d))
|
||||||
counts = np.zeros((m, 1))
|
counts = np.zeros((m, 1))
|
||||||
np.add.at(result, idx, arr)
|
np.add.at(result, idx, arr)
|
||||||
@@ -532,8 +498,7 @@ class FlashLB(EplbPolicy):
|
|||||||
expert_workload += 1
|
expert_workload += 1
|
||||||
num_layer = expert_workload.shape[0]
|
num_layer = expert_workload.shape[0]
|
||||||
num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
|
num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
|
||||||
self.register_hotness(current_deployment, expert_workload, num_layer,
|
self.register_hotness(current_deployment, expert_workload, num_layer, num_expert)
|
||||||
num_expert)
|
|
||||||
|
|
||||||
new_deployment = current_deployment.copy()
|
new_deployment = current_deployment.copy()
|
||||||
|
|
||||||
@@ -544,21 +509,17 @@ class FlashLB(EplbPolicy):
|
|||||||
for i, layer in enumerate(layers_need_update):
|
for i, layer in enumerate(layers_need_update):
|
||||||
hotness = np.array(self.hotness_window[layer])
|
hotness = np.array(self.hotness_window[layer])
|
||||||
if hotness.shape[0] > self.max_stage_window:
|
if hotness.shape[0] > self.max_stage_window:
|
||||||
hotness = self.compress_by_avg_pooling_fast_nd(
|
hotness = self.compress_by_avg_pooling_fast_nd(hotness, self.max_stage_window)
|
||||||
hotness, self.max_stage_window)
|
|
||||||
|
|
||||||
(
|
(
|
||||||
new_deployment[layer],
|
new_deployment[layer],
|
||||||
new_par[i],
|
new_par[i],
|
||||||
current_par[i],
|
current_par[i],
|
||||||
) = self.rebalance_layer(current_deployment[layer],
|
) = self.rebalance_layer(current_deployment[layer], hotness, layer_id=layer)
|
||||||
hotness,
|
|
||||||
layer_id=layer)
|
|
||||||
|
|
||||||
priority = new_par / current_par
|
priority = new_par / current_par
|
||||||
priority_idx = np.argsort(priority)
|
priority_idx = np.argsort(priority)
|
||||||
priority_idx = priority_idx[priority[priority_idx] <
|
priority_idx = priority_idx[priority[priority_idx] < 1][: self.buffer_expert_layer_num]
|
||||||
1][:self.buffer_expert_layer_num]
|
|
||||||
|
|
||||||
if np.all(expert_workload == 1):
|
if np.all(expert_workload == 1):
|
||||||
for _, layer in enumerate(layers_need_update):
|
for _, layer in enumerate(layers_need_update):
|
||||||
@@ -572,16 +533,12 @@ class FlashLB(EplbPolicy):
|
|||||||
layers_need_update = priority_idx
|
layers_need_update = priority_idx
|
||||||
deployment = current_deployment
|
deployment = current_deployment
|
||||||
for layer in layers_need_update:
|
for layer in layers_need_update:
|
||||||
deployment[layer] = auto_fix_new_placement(
|
deployment[layer] = auto_fix_new_placement(current_deployment[layer], new_deployment[layer])
|
||||||
current_deployment[layer], new_deployment[layer])
|
|
||||||
|
|
||||||
return change, layers_need_update, deployment
|
return change, layers_need_update, deployment
|
||||||
|
|
||||||
|
|
||||||
def generate_layered_experts(num_layers=58,
|
def generate_layered_experts(num_layers=58, layer_shape=(32, 9), expert_min=0, expert_max=255):
|
||||||
layer_shape=(32, 9),
|
|
||||||
expert_min=0,
|
|
||||||
expert_max=255):
|
|
||||||
"""
|
"""
|
||||||
Generate expert deployment matrix meeting the following conditions:
|
Generate expert deployment matrix meeting the following conditions:
|
||||||
- Total of num_layers layers
|
- Total of num_layers layers
|
||||||
@@ -598,32 +555,25 @@ def generate_layered_experts(num_layers=58,
|
|||||||
"""
|
"""
|
||||||
# 1. Basic parameter calculation
|
# 1. Basic parameter calculation
|
||||||
expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255)
|
expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255)
|
||||||
layer_total = layer_shape[0] * layer_shape[
|
layer_total = layer_shape[0] * layer_shape[1] # Total elements in a single layer: 32*9=288
|
||||||
1] # Total elements in a single layer: 32*9=288
|
|
||||||
extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32
|
extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32
|
||||||
|
|
||||||
# 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts)
|
# 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts)
|
||||||
assert layer_total >= expert_num, (
|
assert layer_total >= expert_num, (
|
||||||
f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, "
|
f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, cannot cover all experts"
|
||||||
"cannot cover all experts")
|
)
|
||||||
|
|
||||||
# 3. Generate layers one by one
|
# 3. Generate layers one by one
|
||||||
layers = []
|
layers = []
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
# 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
|
# 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
|
||||||
full_experts = torch.arange(expert_min,
|
full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64) # shape (256,)
|
||||||
expert_max + 1,
|
|
||||||
dtype=torch.int64) # shape (256,)
|
|
||||||
|
|
||||||
# 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
|
# 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
|
||||||
extra_experts = torch.randint(expert_min,
|
extra_experts = torch.randint(expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64) # shape (32,)
|
||||||
expert_max + 1,
|
|
||||||
size=(extra_slots, ),
|
|
||||||
dtype=torch.int64) # shape (32,)
|
|
||||||
|
|
||||||
# 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
|
# 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
|
||||||
layer_flat = torch.cat([full_experts, extra_experts],
|
layer_flat = torch.cat([full_experts, extra_experts], dim=0) # shape (288,)
|
||||||
dim=0) # shape (288,)
|
|
||||||
# Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
|
# Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
|
||||||
shuffle_idx = torch.randperm(layer_flat.shape[0])
|
shuffle_idx = torch.randperm(layer_flat.shape[0])
|
||||||
layer_shuffled = layer_flat[shuffle_idx]
|
layer_shuffled = layer_flat[shuffle_idx]
|
||||||
@@ -642,7 +592,6 @@ def warm_up():
|
|||||||
exam_config.num_die_per_host = 16
|
exam_config.num_die_per_host = 16
|
||||||
algo = FlashLB(exam_config)
|
algo = FlashLB(exam_config)
|
||||||
# Generate target tensor
|
# Generate target tensor
|
||||||
expert_tensor = generate_layered_experts(num_layers=58,
|
expert_tensor = generate_layered_experts(num_layers=58, layer_shape=(32, 9))
|
||||||
layer_shape=(32, 9))
|
|
||||||
|
|
||||||
algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))
|
algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ random.seed(42)
|
|||||||
|
|
||||||
|
|
||||||
class RandomLoadBalance(EplbPolicy):
|
class RandomLoadBalance(EplbPolicy):
|
||||||
|
|
||||||
def __init__(self, config: DynamicConfig):
|
def __init__(self, config: DynamicConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ class DynamicConfig:
|
|||||||
|
|
||||||
|
|
||||||
class EplbPolicy:
|
class EplbPolicy:
|
||||||
|
|
||||||
def __init__(self, config: DynamicConfig):
|
def __init__(self, config: DynamicConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -63,7 +62,6 @@ class DynamicTable:
|
|||||||
|
|
||||||
|
|
||||||
class SwiftBalanceEplb(EplbPolicy):
|
class SwiftBalanceEplb(EplbPolicy):
|
||||||
|
|
||||||
def __init__(self, config: DynamicConfig):
|
def __init__(self, config: DynamicConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -89,8 +87,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
return a % b
|
return a % b
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_redundant(current_expert_table, expert_workload,
|
def add_redundant(current_expert_table, expert_workload, num_original_expert):
|
||||||
num_original_expert):
|
|
||||||
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
||||||
workload_new = np.zeros((layer_num, num_original_expert))
|
workload_new = np.zeros((layer_num, num_original_expert))
|
||||||
for layer_idx in range(layer_num):
|
for layer_idx in range(layer_num):
|
||||||
@@ -99,8 +96,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
workload_layer = expert_workload[layer_idx].copy()
|
workload_layer = expert_workload[layer_idx].copy()
|
||||||
for npu_idx in range(npu_num):
|
for npu_idx in range(npu_num):
|
||||||
for expert_idx in range(experts_per_npu):
|
for expert_idx in range(experts_per_npu):
|
||||||
workload_dict[placement_layer[npu_idx][
|
workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx]
|
||||||
expert_idx]] += workload_layer[npu_idx][expert_idx]
|
|
||||||
for expert_idx in range(num_original_expert):
|
for expert_idx in range(num_original_expert):
|
||||||
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
||||||
return workload_new
|
return workload_new
|
||||||
@@ -118,9 +114,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
max_heat_per_layer.append(np.max(npu_heats_now))
|
max_heat_per_layer.append(np.max(npu_heats_now))
|
||||||
return max_heat_per_layer
|
return max_heat_per_layer
|
||||||
|
|
||||||
def calculate_initial_imbalance(self, global_deployment,
|
def calculate_initial_imbalance(self, global_deployment, new_layer_workloads):
|
||||||
new_layer_workloads):
|
|
||||||
|
|
||||||
device_num = global_deployment.shape[1]
|
device_num = global_deployment.shape[1]
|
||||||
layer_imbalance = []
|
layer_imbalance = []
|
||||||
expert_num = np.zeros_like(new_layer_workloads)
|
expert_num = np.zeros_like(new_layer_workloads)
|
||||||
@@ -136,56 +130,54 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
box_workload = 0
|
box_workload = 0
|
||||||
for expert_id in box:
|
for expert_id in box:
|
||||||
update_workload = self.safe_divide(
|
update_workload = self.safe_divide(
|
||||||
new_layer_workloads[layer_id][expert_id],
|
new_layer_workloads[layer_id][expert_id], expert_num[layer_id][expert_id]
|
||||||
expert_num[layer_id][expert_id])
|
)
|
||||||
box_workload += update_workload
|
box_workload += update_workload
|
||||||
total_workload += update_workload
|
total_workload += update_workload
|
||||||
if cur_layer_max_workload < box_workload:
|
if cur_layer_max_workload < box_workload:
|
||||||
cur_layer_max_workload = box_workload
|
cur_layer_max_workload = box_workload
|
||||||
|
|
||||||
cur_layer_imbalance = self.safe_divide(
|
cur_layer_imbalance = self.safe_divide(
|
||||||
cur_layer_max_workload,
|
cur_layer_max_workload, (self.safe_divide(total_workload, device_num))
|
||||||
(self.safe_divide(total_workload, device_num)))
|
)
|
||||||
layer_imbalance.append(cur_layer_imbalance)
|
layer_imbalance.append(cur_layer_imbalance)
|
||||||
|
|
||||||
return layer_imbalance
|
return layer_imbalance
|
||||||
|
|
||||||
def compute_redundant_assignments(self, base_experts,
|
def compute_redundant_assignments(self, base_experts, num_redundant_experts, num_experts):
|
||||||
num_redundant_experts, num_experts):
|
redundant_assignments: list[list[int]] = [[] for _ in range(num_experts)]
|
||||||
|
|
||||||
redundant_assignments: list[list[int]] = [[]
|
|
||||||
for _ in range(num_experts)]
|
|
||||||
current_weights = base_experts.copy()
|
current_weights = base_experts.copy()
|
||||||
|
|
||||||
for i in range(num_redundant_experts):
|
for i in range(num_redundant_experts):
|
||||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||||
|
|
||||||
target_expert = sorted_weights[0]
|
target_expert = sorted_weights[0]
|
||||||
expert_id, original_weight = target_expert
|
expert_id, original_weight = target_expert
|
||||||
|
|
||||||
current_redundancy = len(redundant_assignments[expert_id])
|
current_redundancy = len(redundant_assignments[expert_id])
|
||||||
new_avg_weight = self.safe_divide(
|
new_avg_weight = self.safe_divide(original_weight * (current_redundancy + 1), (current_redundancy + 2))
|
||||||
original_weight * (current_redundancy + 1),
|
|
||||||
(current_redundancy + 2))
|
|
||||||
|
|
||||||
redundant_assignments[expert_id].append(num_experts + i)
|
redundant_assignments[expert_id].append(num_experts + i)
|
||||||
current_weights[sorted_indices[0]] = (expert_id, new_avg_weight)
|
current_weights[sorted_indices[0]] = (expert_id, new_avg_weight)
|
||||||
|
|
||||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||||
|
|
||||||
return redundant_assignments, sorted_weights
|
return redundant_assignments, sorted_weights
|
||||||
|
|
||||||
def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos,
|
def repeat_compute_redundant_assignments(
|
||||||
num_experts, num_exist_expert,
|
self,
|
||||||
device_assignments, device_counts,
|
layer_workloads,
|
||||||
expert_from_device,
|
rendun_pos,
|
||||||
com_between_devices):
|
num_experts,
|
||||||
|
num_exist_expert,
|
||||||
current_weights = np.zeros((num_experts, ), dtype='object')
|
device_assignments,
|
||||||
|
device_counts,
|
||||||
|
expert_from_device,
|
||||||
|
com_between_devices,
|
||||||
|
):
|
||||||
|
current_weights = np.zeros((num_experts,), dtype="object")
|
||||||
for expert_id, workload_weight in enumerate(layer_workloads):
|
for expert_id, workload_weight in enumerate(layer_workloads):
|
||||||
current_weights[expert_id] = (expert_id, workload_weight)
|
current_weights[expert_id] = (expert_id, workload_weight)
|
||||||
|
|
||||||
@@ -195,8 +187,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
devices_with_slots.append(device_id)
|
devices_with_slots.append(device_id)
|
||||||
|
|
||||||
while devices_with_slots:
|
while devices_with_slots:
|
||||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||||
|
|
||||||
for index, target_weight in enumerate(sorted_weights):
|
for index, target_weight in enumerate(sorted_weights):
|
||||||
@@ -210,17 +201,15 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
pos = rendun_pos[cur_device_id].pop()
|
pos = rendun_pos[cur_device_id].pop()
|
||||||
if len(rendun_pos[cur_device_id]) == 0:
|
if len(rendun_pos[cur_device_id]) == 0:
|
||||||
devices_with_slots = [
|
devices_with_slots = [
|
||||||
device_id for device_id in devices_with_slots
|
device_id for device_id in devices_with_slots if device_id != cur_device_id
|
||||||
if device_id != cur_device_id
|
|
||||||
]
|
]
|
||||||
device_assignments[cur_device_id][pos] = expert_id
|
device_assignments[cur_device_id][pos] = expert_id
|
||||||
device_counts[cur_device_id] += 1
|
device_counts[cur_device_id] += 1
|
||||||
communication_box_index = expert_from_device[expert_id]
|
communication_box_index = expert_from_device[expert_id]
|
||||||
com_between_devices[cur_device_id][
|
com_between_devices[cur_device_id][communication_box_index] = expert_id
|
||||||
communication_box_index] = expert_id
|
|
||||||
new_weight = self.safe_divide(
|
new_weight = self.safe_divide(
|
||||||
(original_weight * num_exist_expert[expert_id]),
|
(original_weight * num_exist_expert[expert_id]), (num_exist_expert[expert_id] + 1)
|
||||||
(num_exist_expert[expert_id] + 1))
|
)
|
||||||
sorted_weights[index] = (expert_id, new_weight)
|
sorted_weights[index] = (expert_id, new_weight)
|
||||||
num_exist_expert[expert_id] += 1
|
num_exist_expert[expert_id] += 1
|
||||||
redundancy_successful = True
|
redundancy_successful = True
|
||||||
@@ -228,41 +217,31 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
if redundancy_successful:
|
if redundancy_successful:
|
||||||
break
|
break
|
||||||
|
|
||||||
sorted_indices = np.argsort([id for id, _ in sorted_weights],
|
sorted_indices = np.argsort([id for id, _ in sorted_weights], kind="stable")
|
||||||
kind='stable')
|
|
||||||
sorted_weights = [sorted_weights[i][1] for i in sorted_indices]
|
sorted_weights = [sorted_weights[i][1] for i in sorted_indices]
|
||||||
|
|
||||||
return sorted_weights, device_assignments, device_counts, com_between_devices
|
return sorted_weights, device_assignments, device_counts, com_between_devices
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_expert_list(base_experts, redundant_assignments,
|
def prepare_expert_list(base_experts, redundant_assignments, num_redundant_experts):
|
||||||
num_redundant_experts):
|
|
||||||
redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
|
redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
num_experts = len(redundant_assignments)
|
num_experts = len(redundant_assignments)
|
||||||
for expert_id in range(num_experts):
|
for expert_id in range(num_experts):
|
||||||
for _ in redundant_assignments[expert_id]:
|
for _ in redundant_assignments[expert_id]:
|
||||||
redundant_expert_list[index] = (expert_id,
|
redundant_expert_list[index] = (expert_id, next(w for eid, w in base_experts if eid == expert_id))
|
||||||
next(w
|
|
||||||
for eid, w in base_experts
|
|
||||||
if eid == expert_id))
|
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
sorted_indices = np.argsort([w for _, w in redundant_expert_list],
|
sorted_indices = np.argsort([w for _, w in redundant_expert_list], kind="stable")[::-1]
|
||||||
kind='stable')[::-1]
|
|
||||||
return [redundant_expert_list[i] for i in sorted_indices]
|
return [redundant_expert_list[i] for i in sorted_indices]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def non_redundant_expert_information(origin_deployment, updated_weights,
|
def non_redundant_expert_information(origin_deployment, updated_weights, rendun_pos):
|
||||||
rendun_pos):
|
|
||||||
|
|
||||||
device_num = len(origin_deployment)
|
device_num = len(origin_deployment)
|
||||||
num_experts_per_device = origin_deployment.shape[1]
|
num_experts_per_device = origin_deployment.shape[1]
|
||||||
device_assignments = [[-1 for _ in range(num_experts_per_device)]
|
device_assignments = [[-1 for _ in range(num_experts_per_device)] for _ in range(device_num)]
|
||||||
for _ in range(device_num)]
|
device_weights = [[0 for _ in range(num_experts_per_device)] for _ in range(device_num)]
|
||||||
device_weights = [[0 for _ in range(num_experts_per_device)]
|
|
||||||
for _ in range(device_num)]
|
|
||||||
device_loads = [0] * device_num
|
device_loads = [0] * device_num
|
||||||
device_counts = [0] * device_num
|
device_counts = [0] * device_num
|
||||||
|
|
||||||
@@ -272,8 +251,8 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
continue
|
continue
|
||||||
device_assignments[device_id][index] = expert_id
|
device_assignments[device_id][index] = expert_id
|
||||||
cur_weight = next(
|
cur_weight = next(
|
||||||
weight for expert_id_of_weight, weight in updated_weights
|
weight for expert_id_of_weight, weight in updated_weights if expert_id_of_weight == expert_id
|
||||||
if expert_id_of_weight == expert_id)
|
)
|
||||||
device_weights[device_id][index] = cur_weight
|
device_weights[device_id][index] = cur_weight
|
||||||
device_loads[device_id] += cur_weight
|
device_loads[device_id] += cur_weight
|
||||||
device_counts[device_id] += 1
|
device_counts[device_id] += 1
|
||||||
@@ -292,19 +271,24 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
if num_all_experts[expert_id] == 0:
|
if num_all_experts[expert_id] == 0:
|
||||||
cur_layer_workload.append(-1)
|
cur_layer_workload.append(-1)
|
||||||
else:
|
else:
|
||||||
cur_layer_workload.append(
|
cur_layer_workload.append(self.safe_divide(weight, num_all_experts[expert_id]))
|
||||||
self.safe_divide(weight, num_all_experts[expert_id]))
|
|
||||||
|
|
||||||
return cur_layer_workload, num_all_experts
|
return cur_layer_workload, num_all_experts
|
||||||
|
|
||||||
def distribute_redun_experts(self, layer_workloads, device_assignments,
|
def distribute_redun_experts(
|
||||||
device_weights, device_loads, device_counts,
|
self,
|
||||||
redundant_expert_list, expert_from_device,
|
layer_workloads,
|
||||||
num_experts, rendun_pos):
|
device_assignments,
|
||||||
|
device_weights,
|
||||||
|
device_loads,
|
||||||
|
device_counts,
|
||||||
|
redundant_expert_list,
|
||||||
|
expert_from_device,
|
||||||
|
num_experts,
|
||||||
|
rendun_pos,
|
||||||
|
):
|
||||||
num_devices = len(device_assignments)
|
num_devices = len(device_assignments)
|
||||||
com_between_devices: list[dict[int,
|
com_between_devices: list[dict[int, int]] = [{} for _ in range(num_devices)]
|
||||||
int]] = [{} for _ in range(num_devices)]
|
|
||||||
|
|
||||||
for expert_id, weight in redundant_expert_list:
|
for expert_id, weight in redundant_expert_list:
|
||||||
candidate = -1
|
candidate = -1
|
||||||
@@ -313,8 +297,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
continue
|
continue
|
||||||
if expert_id in device_assignments[dev_id]:
|
if expert_id in device_assignments[dev_id]:
|
||||||
continue
|
continue
|
||||||
if candidate == -1 or device_loads[dev_id] < device_loads[
|
if candidate == -1 or device_loads[dev_id] < device_loads[candidate]:
|
||||||
candidate]:
|
|
||||||
candidate = dev_id
|
candidate = dev_id
|
||||||
if candidate != -1:
|
if candidate != -1:
|
||||||
pos = rendun_pos[candidate].pop()
|
pos = rendun_pos[candidate].pop()
|
||||||
@@ -324,31 +307,42 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
device_counts[candidate] += 1
|
device_counts[candidate] += 1
|
||||||
|
|
||||||
communication_box_index = expert_from_device[expert_id]
|
communication_box_index = expert_from_device[expert_id]
|
||||||
com_between_devices[candidate][
|
com_between_devices[candidate][communication_box_index] = expert_id
|
||||||
communication_box_index] = expert_id
|
|
||||||
|
|
||||||
if any(sublist for sublist in rendun_pos):
|
if any(sublist for sublist in rendun_pos):
|
||||||
cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(
|
cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(layer_workloads, device_assignments)
|
||||||
layer_workloads, device_assignments)
|
|
||||||
|
|
||||||
update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments(
|
update_workload, device_assignments, device_counts, com_between_devices = (
|
||||||
cur_layer_workload, rendun_pos, num_experts, num_exist_expert,
|
self.repeat_compute_redundant_assignments(
|
||||||
device_assignments, device_loads, expert_from_device,
|
cur_layer_workload,
|
||||||
com_between_devices)
|
rendun_pos,
|
||||||
|
num_experts,
|
||||||
|
num_exist_expert,
|
||||||
|
device_assignments,
|
||||||
|
device_loads,
|
||||||
|
expert_from_device,
|
||||||
|
com_between_devices,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
device_loads = [0] * len(device_counts)
|
device_loads = [0] * len(device_counts)
|
||||||
for device_id, device in enumerate(device_assignments):
|
for device_id, device in enumerate(device_assignments):
|
||||||
for index, expert_id in enumerate(device):
|
for index, expert_id in enumerate(device):
|
||||||
device_weights[device_id][index] = update_workload[
|
device_weights[device_id][index] = update_workload[expert_id]
|
||||||
expert_id]
|
|
||||||
device_loads[device_id] += update_workload[expert_id]
|
device_loads[device_id] += update_workload[expert_id]
|
||||||
|
|
||||||
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
||||||
|
|
||||||
def redundancy_again(self, layer_workloads, origin_weights,
|
def redundancy_again(
|
||||||
origin_deployment, expert_from_device, num_node,
|
self,
|
||||||
is_node_redundant, rendun_pos):
|
layer_workloads,
|
||||||
|
origin_weights,
|
||||||
|
origin_deployment,
|
||||||
|
expert_from_device,
|
||||||
|
num_node,
|
||||||
|
is_node_redundant,
|
||||||
|
rendun_pos,
|
||||||
|
):
|
||||||
num_experts = len(origin_weights)
|
num_experts = len(origin_weights)
|
||||||
if is_node_redundant:
|
if is_node_redundant:
|
||||||
num_experts = num_experts * num_node
|
num_experts = num_experts * num_node
|
||||||
@@ -358,25 +352,33 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
num_redundant_experts += len(rank_empty_pos)
|
num_redundant_experts += len(rank_empty_pos)
|
||||||
|
|
||||||
redundant_assignments, updated_weights = self.compute_redundant_assignments(
|
redundant_assignments, updated_weights = self.compute_redundant_assignments(
|
||||||
origin_weights, num_redundant_experts, num_experts)
|
origin_weights, num_redundant_experts, num_experts
|
||||||
|
)
|
||||||
|
|
||||||
redundant_expert_list = self.prepare_expert_list(
|
redundant_expert_list = self.prepare_expert_list(updated_weights, redundant_assignments, num_redundant_experts)
|
||||||
updated_weights, redundant_assignments, num_redundant_experts)
|
|
||||||
|
|
||||||
device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information(
|
device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information(
|
||||||
origin_deployment, updated_weights, rendun_pos)
|
origin_deployment, updated_weights, rendun_pos
|
||||||
|
)
|
||||||
|
|
||||||
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts(
|
device_assignments, device_weights, device_loads, device_counts, com_between_devices = (
|
||||||
layer_workloads, device_assignments, device_weights, device_loads,
|
self.distribute_redun_experts(
|
||||||
device_counts, redundant_expert_list, expert_from_device,
|
layer_workloads,
|
||||||
num_experts, rendun_pos)
|
device_assignments,
|
||||||
|
device_weights,
|
||||||
|
device_loads,
|
||||||
|
device_counts,
|
||||||
|
redundant_expert_list,
|
||||||
|
expert_from_device,
|
||||||
|
num_experts,
|
||||||
|
rendun_pos,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_allocation_report(device_assignments, device_weights,
|
def generate_allocation_report(device_assignments, device_weights, device_loads, device_counts):
|
||||||
device_loads, device_counts):
|
|
||||||
|
|
||||||
report = []
|
report = []
|
||||||
max_load = 0.0
|
max_load = 0.0
|
||||||
|
|
||||||
@@ -384,27 +386,27 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
current_load = device_loads[dev_id]
|
current_load = device_loads[dev_id]
|
||||||
max_load = max(max_load, current_load)
|
max_load = max(max_load, current_load)
|
||||||
|
|
||||||
report.append({
|
report.append(
|
||||||
"device_id": dev_id + 1,
|
{
|
||||||
"assigned_experts": device_assignments[dev_id],
|
"device_id": dev_id + 1,
|
||||||
"expert_weights": device_weights[dev_id],
|
"assigned_experts": device_assignments[dev_id],
|
||||||
"total_load": current_load,
|
"expert_weights": device_weights[dev_id],
|
||||||
"expert_count": device_counts[dev_id]
|
"total_load": current_load,
|
||||||
})
|
"expert_count": device_counts[dev_id],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return report, max_load
|
return report, max_load
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id,
|
def exchange_expert(
|
||||||
next_device_id, cur_layer_result, com_between_devices):
|
cur_exchange_index, next_exchange_index, cur_device_id, next_device_id, cur_layer_result, com_between_devices
|
||||||
|
):
|
||||||
|
cur_device_deployment = cur_layer_result[cur_device_id]["assigned_experts"]
|
||||||
|
next_device_deployment = cur_layer_result[next_device_id]["assigned_experts"]
|
||||||
|
|
||||||
cur_device_deployment = cur_layer_result[cur_device_id][
|
cur_device_weight = cur_layer_result[cur_device_id]["expert_weights"]
|
||||||
'assigned_experts']
|
next_device_weight = cur_layer_result[next_device_id]["expert_weights"]
|
||||||
next_device_deployment = cur_layer_result[next_device_id][
|
|
||||||
'assigned_experts']
|
|
||||||
|
|
||||||
cur_device_weight = cur_layer_result[cur_device_id]['expert_weights']
|
|
||||||
next_device_weight = cur_layer_result[next_device_id]['expert_weights']
|
|
||||||
|
|
||||||
cur_expert_id = cur_device_deployment[cur_exchange_index]
|
cur_expert_id = cur_device_deployment[cur_exchange_index]
|
||||||
next_expert_id = next_device_deployment[next_exchange_index]
|
next_expert_id = next_device_deployment[next_exchange_index]
|
||||||
@@ -416,29 +418,25 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
cur_device_weight[cur_exchange_index] = next_expert_weight
|
cur_device_weight[cur_exchange_index] = next_expert_weight
|
||||||
next_device_weight[next_exchange_index] = cur_expert_weight
|
next_device_weight[next_exchange_index] = cur_expert_weight
|
||||||
|
|
||||||
cur_layer_result[cur_device_id][
|
cur_layer_result[cur_device_id]["total_load"] += next_expert_weight - cur_expert_weight
|
||||||
'total_load'] += next_expert_weight - cur_expert_weight
|
cur_layer_result[next_device_id]["total_load"] += cur_expert_weight - next_expert_weight
|
||||||
cur_layer_result[next_device_id][
|
|
||||||
'total_load'] += cur_expert_weight - next_expert_weight
|
|
||||||
|
|
||||||
com_between_devices[cur_device_id][next_device_id] = next_expert_id
|
com_between_devices[cur_device_id][next_device_id] = next_expert_id
|
||||||
com_between_devices[next_device_id][cur_device_id] = cur_expert_id
|
com_between_devices[next_device_id][cur_device_id] = cur_expert_id
|
||||||
|
|
||||||
def redundant_expert_deployment(self, layer_workloads, original_deployment,
|
def redundant_expert_deployment(
|
||||||
expert_from_device, node_num,
|
self, layer_workloads, original_deployment, expert_from_device, node_num, is_node_redundant, rendun_pos
|
||||||
is_node_redundant, rendun_pos):
|
):
|
||||||
device_num, per_device_expert_num = original_deployment.shape
|
device_num, per_device_expert_num = original_deployment.shape
|
||||||
route_expert_num = layer_workloads.shape[0]
|
route_expert_num = layer_workloads.shape[0]
|
||||||
per_node_device_num = self.safe_exact_divide(device_num, node_num)
|
per_node_device_num = self.safe_exact_divide(device_num, node_num)
|
||||||
per_node_route_expert_num = per_node_device_num * (
|
per_node_route_expert_num = per_node_device_num * (per_device_expert_num - 1)
|
||||||
per_device_expert_num - 1)
|
|
||||||
|
|
||||||
weights = np.zeros((route_expert_num, ), dtype='object')
|
weights = np.zeros((route_expert_num,), dtype="object")
|
||||||
for expert_id, workload_weight in enumerate(layer_workloads):
|
for expert_id, workload_weight in enumerate(layer_workloads):
|
||||||
weights[expert_id] = (expert_id, workload_weight)
|
weights[expert_id] = (expert_id, workload_weight)
|
||||||
|
|
||||||
if is_node_redundant:
|
if is_node_redundant:
|
||||||
|
|
||||||
device_assignments = []
|
device_assignments = []
|
||||||
device_weights = []
|
device_weights = []
|
||||||
device_loads = []
|
device_loads = []
|
||||||
@@ -446,23 +444,30 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
com_between_devices = []
|
com_between_devices = []
|
||||||
|
|
||||||
for node_id in range(node_num):
|
for node_id in range(node_num):
|
||||||
cur_node_weights = weights[node_id *
|
cur_node_weights = weights[
|
||||||
per_node_route_expert_num:(node_id +
|
node_id * per_node_route_expert_num : (node_id + 1) * per_node_route_expert_num
|
||||||
1) *
|
]
|
||||||
per_node_route_expert_num]
|
|
||||||
cur_original_deployment = original_deployment[
|
cur_original_deployment = original_deployment[
|
||||||
node_id * per_node_device_num:(node_id + 1) *
|
node_id * per_node_device_num : (node_id + 1) * per_node_device_num
|
||||||
per_node_device_num]
|
]
|
||||||
|
|
||||||
cur_node_rendun_pos = rendun_pos[node_id *
|
cur_node_rendun_pos = rendun_pos[node_id * per_node_device_num : (node_id + 1) * per_node_device_num]
|
||||||
per_node_device_num:(node_id +
|
|
||||||
1) *
|
|
||||||
per_node_device_num]
|
|
||||||
|
|
||||||
cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again(
|
(
|
||||||
layer_workloads, cur_node_weights, cur_original_deployment,
|
cur_device_assignments,
|
||||||
expert_from_device, node_num, is_node_redundant,
|
cur_device_weights,
|
||||||
cur_node_rendun_pos)
|
cur_device_loads,
|
||||||
|
cur_device_counts,
|
||||||
|
cur_com_between_devices,
|
||||||
|
) = self.redundancy_again(
|
||||||
|
layer_workloads,
|
||||||
|
cur_node_weights,
|
||||||
|
cur_original_deployment,
|
||||||
|
expert_from_device,
|
||||||
|
node_num,
|
||||||
|
is_node_redundant,
|
||||||
|
cur_node_rendun_pos,
|
||||||
|
)
|
||||||
device_assignments += cur_device_assignments
|
device_assignments += cur_device_assignments
|
||||||
device_weights += cur_device_weights
|
device_weights += cur_device_weights
|
||||||
device_loads += cur_device_loads
|
device_loads += cur_device_loads
|
||||||
@@ -470,28 +475,41 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
com_between_devices += cur_com_between_devices
|
com_between_devices += cur_com_between_devices
|
||||||
|
|
||||||
else:
|
else:
|
||||||
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again(
|
device_assignments, device_weights, device_loads, device_counts, com_between_devices = (
|
||||||
layer_workloads, weights, original_deployment,
|
self.redundancy_again(
|
||||||
expert_from_device, node_num, is_node_redundant, rendun_pos)
|
layer_workloads,
|
||||||
|
weights,
|
||||||
|
original_deployment,
|
||||||
|
expert_from_device,
|
||||||
|
node_num,
|
||||||
|
is_node_redundant,
|
||||||
|
rendun_pos,
|
||||||
|
)
|
||||||
|
)
|
||||||
report, max_load = self.generate_allocation_report(
|
report, max_load = self.generate_allocation_report(
|
||||||
device_assignments, device_weights, device_loads, device_counts)
|
device_assignments, device_weights, device_loads, device_counts
|
||||||
|
)
|
||||||
|
|
||||||
return report, max_load, com_between_devices
|
return report, max_load, com_between_devices
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def two_device_exchange_experts(cur_device_result, exchange_device_result,
|
def two_device_exchange_experts(
|
||||||
cur_exchanged_expert_id,
|
cur_device_result,
|
||||||
next_exchanged_expert_id, ave_workload,
|
exchange_device_result,
|
||||||
increment, num_redundancy_expert):
|
cur_exchanged_expert_id,
|
||||||
|
next_exchanged_expert_id,
|
||||||
|
ave_workload,
|
||||||
|
increment,
|
||||||
|
num_redundancy_expert,
|
||||||
|
):
|
||||||
|
cur_device_weight = cur_device_result["expert_weights"]
|
||||||
|
next_device_weight = exchange_device_result["expert_weights"]
|
||||||
|
|
||||||
cur_device_weight = cur_device_result['expert_weights']
|
cur_device_expert_id = cur_device_result["assigned_experts"]
|
||||||
next_device_weight = exchange_device_result['expert_weights']
|
next_device_expert_id = exchange_device_result["assigned_experts"]
|
||||||
|
|
||||||
cur_device_expert_id = cur_device_result['assigned_experts']
|
cur_device_total_weight = cur_device_result["total_load"]
|
||||||
next_device_expert_id = exchange_device_result['assigned_experts']
|
next_device_total_weight = exchange_device_result["total_load"]
|
||||||
|
|
||||||
cur_device_total_weight = cur_device_result['total_load']
|
|
||||||
next_device_total_weight = exchange_device_result['total_load']
|
|
||||||
max_weight = max(cur_device_total_weight, next_device_total_weight)
|
max_weight = max(cur_device_total_weight, next_device_total_weight)
|
||||||
|
|
||||||
cur_exchange_index = -1
|
cur_exchange_index = -1
|
||||||
@@ -500,50 +518,47 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
for index, weight in enumerate(cur_device_weight):
|
for index, weight in enumerate(cur_device_weight):
|
||||||
for next_index, next_weight in enumerate(next_device_weight):
|
for next_index, next_weight in enumerate(next_device_weight):
|
||||||
change_flag = True
|
change_flag = True
|
||||||
if (cur_device_expert_id[index] in next_device_expert_id
|
if (
|
||||||
or next_device_expert_id[next_index]
|
cur_device_expert_id[index] in next_device_expert_id
|
||||||
in cur_device_expert_id):
|
or next_device_expert_id[next_index] in cur_device_expert_id
|
||||||
|
):
|
||||||
change_flag = False
|
change_flag = False
|
||||||
if (cur_device_expert_id[index] not in cur_exchanged_expert_id
|
if (
|
||||||
) and (next_device_expert_id[next_index]
|
(cur_device_expert_id[index] not in cur_exchanged_expert_id)
|
||||||
not in next_exchanged_expert_id) and change_flag:
|
and (next_device_expert_id[next_index] not in next_exchanged_expert_id)
|
||||||
|
and change_flag
|
||||||
|
):
|
||||||
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
|
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
|
||||||
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
|
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
|
||||||
exchange_max_weight = max(
|
exchange_max_weight = max(cur_total_weight_after_exchange, next_total_weight_after_exchange)
|
||||||
cur_total_weight_after_exchange,
|
if exchange_max_weight < max_weight and (max_weight - exchange_max_weight) >= (
|
||||||
next_total_weight_after_exchange)
|
ave_workload * increment
|
||||||
if exchange_max_weight < max_weight and (
|
):
|
||||||
max_weight -
|
|
||||||
exchange_max_weight) >= (ave_workload * increment):
|
|
||||||
max_weight = exchange_max_weight
|
max_weight = exchange_max_weight
|
||||||
cur_exchange_index = index
|
cur_exchange_index = index
|
||||||
next_exchange_index = next_index
|
next_exchange_index = next_index
|
||||||
|
|
||||||
return cur_exchange_index, next_exchange_index
|
return cur_exchange_index, next_exchange_index
|
||||||
|
|
||||||
def expert_exchange_between_devices(self,
|
def expert_exchange_between_devices(
|
||||||
ave_workload,
|
self,
|
||||||
increment,
|
ave_workload,
|
||||||
cur_layer_result,
|
increment,
|
||||||
com_between_devices,
|
cur_layer_result,
|
||||||
num_redundancy_expert,
|
com_between_devices,
|
||||||
node_idx=0,
|
num_redundancy_expert,
|
||||||
per_node_device_num=0,
|
node_idx=0,
|
||||||
is_node_redundant=False):
|
per_node_device_num=0,
|
||||||
|
is_node_redundant=False,
|
||||||
|
):
|
||||||
if is_node_redundant:
|
if is_node_redundant:
|
||||||
cur_devices_result = cur_layer_result[node_idx *
|
cur_devices_result = cur_layer_result[node_idx * per_node_device_num : (node_idx + 1) * per_node_device_num]
|
||||||
per_node_device_num:
|
|
||||||
(node_idx + 1) *
|
|
||||||
per_node_device_num]
|
|
||||||
else:
|
else:
|
||||||
cur_devices_result = cur_layer_result
|
cur_devices_result = cur_layer_result
|
||||||
|
|
||||||
devices_total_weight = []
|
devices_total_weight = []
|
||||||
for device in cur_devices_result:
|
for device in cur_devices_result:
|
||||||
devices_total_weight.append(
|
devices_total_weight.append((device["total_load"], device["device_id"] - 1))
|
||||||
(device['total_load'], device['device_id'] - 1))
|
|
||||||
|
|
||||||
exchange_frequency = 100
|
exchange_frequency = 100
|
||||||
while exchange_frequency > 0:
|
while exchange_frequency > 0:
|
||||||
@@ -553,64 +568,81 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
exchange = False
|
exchange = False
|
||||||
for index in range(0, len(devices_total_weight) - 1):
|
for index in range(0, len(devices_total_weight) - 1):
|
||||||
min_weight_device_id = devices_total_weight[index][1]
|
min_weight_device_id = devices_total_weight[index][1]
|
||||||
if min_weight_device_id not in com_between_devices[
|
if min_weight_device_id not in com_between_devices[max_weight_device_id]:
|
||||||
max_weight_device_id]:
|
cur_exchanged_expert_id = list(com_between_devices[max_weight_device_id].values())
|
||||||
cur_exchanged_expert_id = list(
|
next_exchanged_expert_id = list(com_between_devices[min_weight_device_id].values())
|
||||||
com_between_devices[max_weight_device_id].values())
|
|
||||||
next_exchanged_expert_id = list(
|
|
||||||
com_between_devices[min_weight_device_id].values())
|
|
||||||
|
|
||||||
cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
|
cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
|
||||||
cur_layer_result[max_weight_device_id],
|
cur_layer_result[max_weight_device_id],
|
||||||
cur_layer_result[min_weight_device_id],
|
cur_layer_result[min_weight_device_id],
|
||||||
cur_exchanged_expert_id, next_exchanged_expert_id,
|
cur_exchanged_expert_id,
|
||||||
ave_workload, increment, num_redundancy_expert)
|
next_exchanged_expert_id,
|
||||||
|
ave_workload,
|
||||||
|
increment,
|
||||||
|
num_redundancy_expert,
|
||||||
|
)
|
||||||
|
|
||||||
if cur_exchange_index != -1:
|
if cur_exchange_index != -1:
|
||||||
self.exchange_expert(cur_exchange_index,
|
self.exchange_expert(
|
||||||
next_exchange_index,
|
cur_exchange_index,
|
||||||
max_weight_device_id,
|
next_exchange_index,
|
||||||
min_weight_device_id,
|
max_weight_device_id,
|
||||||
cur_layer_result,
|
min_weight_device_id,
|
||||||
com_between_devices)
|
cur_layer_result,
|
||||||
|
com_between_devices,
|
||||||
|
)
|
||||||
|
|
||||||
devices_total_weight[-1] = (
|
devices_total_weight[-1] = (
|
||||||
cur_layer_result[max_weight_device_id]
|
cur_layer_result[max_weight_device_id]["total_load"],
|
||||||
['total_load'], max_weight_device_id)
|
max_weight_device_id,
|
||||||
|
)
|
||||||
devices_total_weight[index] = (
|
devices_total_weight[index] = (
|
||||||
cur_layer_result[min_weight_device_id]
|
cur_layer_result[min_weight_device_id]["total_load"],
|
||||||
['total_load'], min_weight_device_id)
|
min_weight_device_id,
|
||||||
|
)
|
||||||
exchange = True
|
exchange = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if not exchange:
|
if not exchange:
|
||||||
break
|
break
|
||||||
|
|
||||||
def exchange_experts(self, layer_result, layer_com_between_devices,
|
def exchange_experts(
|
||||||
num_nodes, device_num, is_node_redundant,
|
self,
|
||||||
ave_workload, increment, num_redundancy_expert,
|
layer_result,
|
||||||
org_deployment):
|
layer_com_between_devices,
|
||||||
|
num_nodes,
|
||||||
|
device_num,
|
||||||
|
is_node_redundant,
|
||||||
|
ave_workload,
|
||||||
|
increment,
|
||||||
|
num_redundancy_expert,
|
||||||
|
org_deployment,
|
||||||
|
):
|
||||||
global_deployment = []
|
global_deployment = []
|
||||||
|
|
||||||
if is_node_redundant:
|
if is_node_redundant:
|
||||||
per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
|
per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
|
||||||
for node_idx in range(num_nodes):
|
for node_idx in range(num_nodes):
|
||||||
self.expert_exchange_between_devices(
|
self.expert_exchange_between_devices(
|
||||||
ave_workload, increment, layer_result,
|
ave_workload,
|
||||||
layer_com_between_devices, num_redundancy_expert, node_idx,
|
increment,
|
||||||
per_node_device_num, is_node_redundant)
|
layer_result,
|
||||||
|
layer_com_between_devices,
|
||||||
|
num_redundancy_expert,
|
||||||
|
node_idx,
|
||||||
|
per_node_device_num,
|
||||||
|
is_node_redundant,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.expert_exchange_between_devices(ave_workload, increment,
|
self.expert_exchange_between_devices(
|
||||||
layer_result,
|
ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert
|
||||||
layer_com_between_devices,
|
)
|
||||||
num_redundancy_expert)
|
|
||||||
|
|
||||||
max_workload = 0
|
max_workload = 0
|
||||||
for box in layer_result:
|
for box in layer_result:
|
||||||
global_deployment.append(box['assigned_experts'])
|
global_deployment.append(box["assigned_experts"])
|
||||||
if max_workload < box['total_load']:
|
if max_workload < box["total_load"]:
|
||||||
max_workload = box['total_load']
|
max_workload = box["total_load"]
|
||||||
|
|
||||||
global_deployment = np.array(global_deployment)
|
global_deployment = np.array(global_deployment)
|
||||||
|
|
||||||
@@ -626,16 +658,11 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
return count
|
return count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def constraint_expert_local_exchange(current_expert_table,
|
def constraint_expert_local_exchange(current_expert_table, global_deployment):
|
||||||
global_deployment):
|
|
||||||
for layer_id in range(len(global_deployment)):
|
for layer_id in range(len(global_deployment)):
|
||||||
for card_id in range(len(global_deployment[layer_id])):
|
for card_id in range(len(global_deployment[layer_id])):
|
||||||
current_list = [
|
current_list = [int(x) for x in current_expert_table[layer_id][card_id]]
|
||||||
int(x) for x in current_expert_table[layer_id][card_id]
|
new_list = [int(x) for x in global_deployment[layer_id][card_id]]
|
||||||
]
|
|
||||||
new_list = [
|
|
||||||
int(x) for x in global_deployment[layer_id][card_id]
|
|
||||||
]
|
|
||||||
num = len(new_list)
|
num = len(new_list)
|
||||||
|
|
||||||
new_index = [-1] * num
|
new_index = [-1] * num
|
||||||
@@ -645,8 +672,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
for i in range(num):
|
for i in range(num):
|
||||||
flag = True
|
flag = True
|
||||||
for j in range(num):
|
for j in range(num):
|
||||||
if new_list[i] == current_list[j] and new_index[
|
if new_list[i] == current_list[j] and new_index[j] == -1:
|
||||||
j] == -1:
|
|
||||||
new_index[j] = 0
|
new_index[j] = 0
|
||||||
new_result[j] = current_list[j]
|
new_result[j] = current_list[j]
|
||||||
flag = False
|
flag = False
|
||||||
@@ -664,25 +690,17 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
|
|
||||||
return global_deployment
|
return global_deployment
|
||||||
|
|
||||||
def rebalance_experts(self,
|
def rebalance_experts(self, current_expert_table, expert_workload, is_node_redundant=False, increment=0.01):
|
||||||
current_expert_table,
|
|
||||||
expert_workload,
|
|
||||||
is_node_redundant=False,
|
|
||||||
increment=0.01):
|
|
||||||
info = DynamicTable()
|
info = DynamicTable()
|
||||||
info.workload_table = expert_workload.numpy()
|
info.workload_table = expert_workload.numpy()
|
||||||
info.placement_table = current_expert_table.numpy()
|
info.placement_table = current_expert_table.numpy()
|
||||||
assert info.workload_table is not None
|
assert info.workload_table is not None
|
||||||
layer_num, num_npus, experts_per_npu = info.workload_table.shape
|
layer_num, num_npus, experts_per_npu = info.workload_table.shape
|
||||||
expert_ids, counts = np.unique(info.placement_table[0],
|
expert_ids, counts = np.unique(info.placement_table[0], return_counts=True)
|
||||||
return_counts=True)
|
|
||||||
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
||||||
num_original_expert = len(expert_ids)
|
num_original_expert = len(expert_ids)
|
||||||
layer_workloads = self.add_redundant(info.placement_table,
|
layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert)
|
||||||
info.workload_table,
|
max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num)
|
||||||
num_original_expert)
|
|
||||||
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
|
|
||||||
info.workload_table, layer_num)
|
|
||||||
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
||||||
|
|
||||||
num_node = self.safe_exact_divide(num_npus, 8)
|
num_node = self.safe_exact_divide(num_npus, 8)
|
||||||
@@ -700,14 +718,13 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
|
|
||||||
if num_npus < num_redundancy_expert:
|
if num_npus < num_redundancy_expert:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})"
|
"The number of NPUs "
|
||||||
|
f"({num_npus}) must be greater than or equal to the number of redundant experts "
|
||||||
|
f"({num_redundancy_expert})"
|
||||||
)
|
)
|
||||||
|
|
||||||
global_deployment: list[list[list[int]]] = [[[]
|
global_deployment: list[list[list[int]]] = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
|
||||||
for _ in range(num_npus)]
|
layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads)
|
||||||
for _ in range(layer_num)]
|
|
||||||
layer_initial_imbalance = self.calculate_initial_imbalance(
|
|
||||||
info.placement_table, layer_workloads)
|
|
||||||
max_heat_per_layer_after = np.zeros([layer_num])
|
max_heat_per_layer_after = np.zeros([layer_num])
|
||||||
sum_num = 0
|
sum_num = 0
|
||||||
for layer in range(layer_num):
|
for layer in range(layer_num):
|
||||||
@@ -715,8 +732,7 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
global_deployment[layer] = info.placement_table[layer]
|
global_deployment[layer] = info.placement_table[layer]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]),
|
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), num_npus)
|
||||||
num_npus)
|
|
||||||
|
|
||||||
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
|
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
|
||||||
existing_experts = set()
|
existing_experts = set()
|
||||||
@@ -729,30 +745,37 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
rendun_pos[device_id].append(index)
|
rendun_pos[device_id].append(index)
|
||||||
|
|
||||||
result, max_workload, com_between_devices = self.redundant_expert_deployment(
|
result, max_workload, com_between_devices = self.redundant_expert_deployment(
|
||||||
layer_workloads[layer], info.placement_table[layer],
|
layer_workloads[layer],
|
||||||
expert_from_device[layer], num_node, is_node_redundant,
|
info.placement_table[layer],
|
||||||
rendun_pos)
|
expert_from_device[layer],
|
||||||
|
num_node,
|
||||||
|
is_node_redundant,
|
||||||
|
rendun_pos,
|
||||||
|
)
|
||||||
|
|
||||||
global_deployment[layer], new_max_workload = self.exchange_experts(
|
global_deployment[layer], new_max_workload = self.exchange_experts(
|
||||||
result, com_between_devices, num_node, num_npus,
|
result,
|
||||||
is_node_redundant, ave_workload, increment,
|
com_between_devices,
|
||||||
num_redundancy_expert, info.placement_table[layer])
|
num_node,
|
||||||
|
num_npus,
|
||||||
|
is_node_redundant,
|
||||||
|
ave_workload,
|
||||||
|
increment,
|
||||||
|
num_redundancy_expert,
|
||||||
|
info.placement_table[layer],
|
||||||
|
)
|
||||||
|
|
||||||
for device_id in range(num_npus):
|
for device_id in range(num_npus):
|
||||||
com_between_devices[device_id] = {
|
com_between_devices[device_id] = {key: value for key, value in com_between_devices[device_id].items()}
|
||||||
key: value
|
|
||||||
for key, value in com_between_devices[device_id].items()
|
|
||||||
}
|
|
||||||
sum_num += self.count_elements(com_between_devices[device_id])
|
sum_num += self.count_elements(com_between_devices[device_id])
|
||||||
|
|
||||||
max_heat_per_layer_after[layer] = max(
|
max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_load"])["total_load"]
|
||||||
result, key=lambda x: x['total_load'])['total_load']
|
|
||||||
|
|
||||||
layer_changed_ratio = []
|
layer_changed_ratio = []
|
||||||
for layer_idx in range(layer_num):
|
for layer_idx in range(layer_num):
|
||||||
layer_changed_ratio.append(
|
layer_changed_ratio.append(
|
||||||
self.safe_divide(max_heat_per_layer_after[layer_idx],
|
self.safe_divide(max_heat_per_layer_after[layer_idx], max_heat_per_layer_before[layer_idx])
|
||||||
max_heat_per_layer_before[layer_idx]))
|
)
|
||||||
|
|
||||||
per_layer_priority = np.argsort(layer_changed_ratio)
|
per_layer_priority = np.argsort(layer_changed_ratio)
|
||||||
npu_heat_all_after = sum(max_heat_per_layer_after)
|
npu_heat_all_after = sum(max_heat_per_layer_after)
|
||||||
@@ -761,8 +784,6 @@ class SwiftBalanceEplb(EplbPolicy):
|
|||||||
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
||||||
change = 1
|
change = 1
|
||||||
|
|
||||||
new_global_deployment = self.constraint_expert_local_exchange(
|
new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
|
||||||
current_expert_table, global_deployment)
|
|
||||||
|
|
||||||
return change, per_layer_priority, np.array(
|
return change, per_layer_priority, np.array(new_global_deployment).tolist()
|
||||||
new_global_deployment).tolist()
|
|
||||||
|
|||||||
@@ -26,9 +26,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
|||||||
|
|
||||||
|
|
||||||
class EplbUpdator:
|
class EplbUpdator:
|
||||||
|
def __init__(self, eplb_config, loader, eplb_process: EplbProcess, process):
|
||||||
def __init__(self, eplb_config, loader, eplb_process: EplbProcess,
|
|
||||||
process):
|
|
||||||
self.eplb_config = eplb_config
|
self.eplb_config = eplb_config
|
||||||
self.init_eplb(self.eplb_config.expert_map_path, process)
|
self.init_eplb(self.eplb_config.expert_map_path, process)
|
||||||
self.eplb_loader = loader
|
self.eplb_loader = loader
|
||||||
@@ -43,9 +41,7 @@ class EplbUpdator:
|
|||||||
self.world_size = dist.get_world_size()
|
self.world_size = dist.get_world_size()
|
||||||
self.device = local_load.device
|
self.device = local_load.device
|
||||||
shape = (self.world_size, *local_load.shape)
|
shape = (self.world_size, *local_load.shape)
|
||||||
self._gather_buffer = torch.empty(shape,
|
self._gather_buffer = torch.empty(shape, dtype=local_load.dtype, device=self.device)
|
||||||
dtype=local_load.dtype,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
def init_eplb(self, expert_map_path, process):
|
def init_eplb(self, expert_map_path, process):
|
||||||
self.rank_id = dist.get_rank()
|
self.rank_id = dist.get_rank()
|
||||||
@@ -72,52 +68,49 @@ class EplbUpdator:
|
|||||||
|
|
||||||
self.process = process
|
self.process = process
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")
|
||||||
f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")
|
|
||||||
|
|
||||||
def update_iteration(self):
|
def update_iteration(self):
|
||||||
self.cur_iterations += 1
|
self.cur_iterations += 1
|
||||||
if self.cur_iterations == (self.expert_heat_collection_interval + \
|
if self.cur_iterations == (
|
||||||
self.algorithm_execution_interval + self.num_moe_layers):
|
self.expert_heat_collection_interval + self.algorithm_execution_interval + self.num_moe_layers
|
||||||
|
):
|
||||||
logger.info("Finish expert parallel load balancing.")
|
logger.info("Finish expert parallel load balancing.")
|
||||||
if self.expert_map_record_path is not None:
|
if self.expert_map_record_path is not None:
|
||||||
self.adaptor._export_tensor_to_file(
|
self.adaptor._export_tensor_to_file(self.shared_dict["expert_maps"], self.expert_map_record_path)
|
||||||
self.shared_dict["expert_maps"],
|
|
||||||
self.expert_map_record_path)
|
|
||||||
|
|
||||||
self.adaptor.model.clear_all_moe_loads()
|
self.adaptor.model.clear_all_moe_loads()
|
||||||
self.cur_iterations = 0
|
self.cur_iterations = 0
|
||||||
|
|
||||||
def get_update_info_flag(self):
|
def get_update_info_flag(self):
|
||||||
return self.cur_iterations == (self.expert_heat_collection_interval +
|
return self.cur_iterations == (self.expert_heat_collection_interval + self.algorithm_execution_interval - 1)
|
||||||
self.algorithm_execution_interval - 1)
|
|
||||||
|
|
||||||
def wakeup_eplb_worker_flag(self):
|
def wakeup_eplb_worker_flag(self):
|
||||||
return self.cur_iterations == (self.expert_heat_collection_interval -
|
return self.cur_iterations == (self.expert_heat_collection_interval - 1)
|
||||||
1)
|
|
||||||
|
|
||||||
def update_expert_weight_flag(self):
|
def update_expert_weight_flag(self):
|
||||||
weight_update_counter = self.cur_iterations - (
|
weight_update_counter = self.cur_iterations - (
|
||||||
self.expert_heat_collection_interval +
|
self.expert_heat_collection_interval + self.algorithm_execution_interval
|
||||||
self.algorithm_execution_interval)
|
)
|
||||||
return (weight_update_counter >= 0
|
return weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers
|
||||||
and weight_update_counter < self.num_moe_layers)
|
|
||||||
|
|
||||||
def wakeup_eplb_worker(self):
|
def wakeup_eplb_worker(self):
|
||||||
self.eplb_process.planner_q.put(1)
|
self.eplb_process.planner_q.put(1)
|
||||||
|
|
||||||
def forward_before(self):
|
def forward_before(self):
|
||||||
if self.update_expert_weight_flag():
|
if self.update_expert_weight_flag():
|
||||||
(expert_send_info, expert_recv_info, updated_expert_map,
|
(expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
|
||||||
log2phy_map, layer_id) = self.update_info_all.pop(0)
|
0
|
||||||
|
)
|
||||||
log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
|
log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
|
||||||
self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
|
self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
|
||||||
updated_expert_map_this_rank = torch.from_numpy(
|
updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
|
||||||
numpy.array(updated_expert_map))
|
|
||||||
self.eplb_loader.generate_expert_d2d_transfer_task(
|
self.eplb_loader.generate_expert_d2d_transfer_task(
|
||||||
expert_send_info, expert_recv_info,
|
expert_send_info,
|
||||||
|
expert_recv_info,
|
||||||
updated_expert_map_this_rank,
|
updated_expert_map_this_rank,
|
||||||
layer_id + self.adaptor.num_dense_layers)
|
layer_id + self.adaptor.num_dense_layers,
|
||||||
|
)
|
||||||
|
|
||||||
# set asynchronous stream for d2d expert weight update
|
# set asynchronous stream for d2d expert weight update
|
||||||
self.reqs = []
|
self.reqs = []
|
||||||
@@ -133,8 +126,7 @@ class EplbUpdator:
|
|||||||
self.compute_and_set_moe_load()
|
self.compute_and_set_moe_load()
|
||||||
self.wakeup_eplb_worker()
|
self.wakeup_eplb_worker()
|
||||||
|
|
||||||
if self.update_expert_weight_flag(
|
if self.update_expert_weight_flag() and self.expert_map_record_path is None:
|
||||||
) and self.expert_map_record_path is None:
|
|
||||||
self.eplb_loader.update_expert_map_and_weight(self.reqs)
|
self.eplb_loader.update_expert_map_and_weight(self.reqs)
|
||||||
|
|
||||||
self.update_iteration()
|
self.update_iteration()
|
||||||
@@ -145,9 +137,7 @@ class EplbUpdator:
|
|||||||
|
|
||||||
moe_load = self._gather_buffer.permute(1, 0, 2)
|
moe_load = self._gather_buffer.permute(1, 0, 2)
|
||||||
self.shared_dict["moe_load"] = moe_load.cpu()
|
self.shared_dict["moe_load"] = moe_load.cpu()
|
||||||
logger.debug(
|
logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
|
||||||
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
self.compute_moe_imbalance(moe_load)
|
self.compute_moe_imbalance(moe_load)
|
||||||
@@ -156,7 +146,6 @@ class EplbUpdator:
|
|||||||
return moe_load
|
return moe_load
|
||||||
|
|
||||||
def compute_moe_imbalance(self, moe_load: torch.Tensor):
|
def compute_moe_imbalance(self, moe_load: torch.Tensor):
|
||||||
|
|
||||||
self.moe_imbalance_dict.clear()
|
self.moe_imbalance_dict.clear()
|
||||||
|
|
||||||
layer_card_load = moe_load.sum(dim=-1).cpu().float()
|
layer_card_load = moe_load.sum(dim=-1).cpu().float()
|
||||||
@@ -169,13 +158,11 @@ class EplbUpdator:
|
|||||||
|
|
||||||
moe_load_imbalance = max_load / (mean_load + 1e-6)
|
moe_load_imbalance = max_load / (mean_load + 1e-6)
|
||||||
|
|
||||||
logger.debug(f"[ModelRunner][MOE_load_stats][Layer {layer_idx}] "
|
logger.debug(f"[ModelRunner][MOE_load_stats][Layer {layer_idx}] PAR={moe_load_imbalance:.4f}")
|
||||||
f"PAR={moe_load_imbalance:.4f}")
|
|
||||||
|
|
||||||
self.moe_imbalance_dict[layer_idx] = moe_load_imbalance
|
self.moe_imbalance_dict[layer_idx] = moe_load_imbalance
|
||||||
|
|
||||||
def summarize_moe_imbalance(self):
|
def summarize_moe_imbalance(self):
|
||||||
|
|
||||||
values = list(self.moe_imbalance_dict.values())
|
values = list(self.moe_imbalance_dict.values())
|
||||||
if not values:
|
if not values:
|
||||||
logger.info("[MOE_load_stats] No data available.")
|
logger.info("[MOE_load_stats] No data available.")
|
||||||
@@ -191,11 +178,10 @@ class EplbUpdator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def warm_up_eplb(self):
|
def warm_up_eplb(self):
|
||||||
|
|
||||||
self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map()
|
self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map()
|
||||||
self.compute_and_set_moe_load()
|
self.compute_and_set_moe_load()
|
||||||
|
|
||||||
src_tensor = torch.empty((1, ), device=self.device)
|
src_tensor = torch.empty((1,), device=self.device)
|
||||||
self_rank = dist.get_rank()
|
self_rank = dist.get_rank()
|
||||||
|
|
||||||
comm_op_list = []
|
comm_op_list = []
|
||||||
|
|||||||
@@ -30,33 +30,30 @@ def get_log2phy_map(self, layer_id):
|
|||||||
|
|
||||||
def get_all_expert_map(self, num_moe_layers):
|
def get_all_expert_map(self, num_moe_layers):
|
||||||
all_loads = []
|
all_loads = []
|
||||||
num_dense_layers = self.num_dense_layers if hasattr(
|
num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0
|
||||||
self, "num_dense_layers") else 0
|
|
||||||
for layer_id in range(num_moe_layers):
|
for layer_id in range(num_moe_layers):
|
||||||
load_tensor = self.get_expert_map(
|
load_tensor = self.get_expert_map(layer_id + num_dense_layers) # (num_experts_per_layer,)
|
||||||
layer_id + num_dense_layers) # (num_experts_per_layer,)
|
|
||||||
all_loads.append(load_tensor)
|
all_loads.append(load_tensor)
|
||||||
|
|
||||||
return torch.stack(all_loads, dim=0)
|
return torch.stack(all_loads, dim=0)
|
||||||
|
|
||||||
|
|
||||||
def get_all_moe_loads(self):
|
def get_all_moe_loads(self):
|
||||||
num_dense_layers = self.num_dense_layers if hasattr(
|
num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0
|
||||||
self, "num_dense_layers") else 0
|
|
||||||
all_moe_loads = torch.stack(
|
all_moe_loads = torch.stack(
|
||||||
[self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
|
[
|
||||||
for layer_id in range(self.num_moe_layers)],
|
self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load
|
||||||
dim=0
|
for layer_id in range(self.num_moe_layers)
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
)
|
)
|
||||||
return all_moe_loads
|
return all_moe_loads
|
||||||
|
|
||||||
|
|
||||||
def clear_all_moe_loads(self):
|
def clear_all_moe_loads(self):
|
||||||
num_dense_layers = self.num_dense_layers if hasattr(
|
num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0
|
||||||
self, "num_dense_layers") else 0
|
|
||||||
for layer_id in range(self.num_moe_layers):
|
for layer_id in range(self.num_moe_layers):
|
||||||
self.model.layers[layer_id +
|
self.model.layers[layer_id + num_dense_layers].mlp.experts.clear_moe_load()
|
||||||
num_dense_layers].mlp.experts.clear_moe_load()
|
|
||||||
|
|
||||||
|
|
||||||
def model_register(model, model_config):
|
def model_register(model, model_config):
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
from .netloader_pg import (destroy_stateless_process_group,
|
from .netloader_pg import destroy_stateless_process_group, stateless_init_process_group
|
||||||
stateless_init_process_group)
|
|
||||||
|
|
||||||
|
|
||||||
class P2PLoad:
|
class P2PLoad:
|
||||||
@@ -56,9 +55,7 @@ class P2PLoad:
|
|||||||
- The model if loading is successful, otherwise None.
|
- The model if loading is successful, otherwise None.
|
||||||
"""
|
"""
|
||||||
model_device = next(model.parameters()).device
|
model_device = next(model.parameters()).device
|
||||||
logger.info(
|
logger.info(f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||||
f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
|
||||||
)
|
|
||||||
receiver_pg = None
|
receiver_pg = None
|
||||||
loaded_model = None
|
loaded_model = None
|
||||||
try:
|
try:
|
||||||
@@ -67,15 +64,13 @@ class P2PLoad:
|
|||||||
port=self.source_port,
|
port=self.source_port,
|
||||||
rank=0,
|
rank=0,
|
||||||
world_size=2,
|
world_size=2,
|
||||||
group_name='netloader',
|
group_name="netloader",
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||||
f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
|
||||||
)
|
|
||||||
logger.info(f"Model device: {model_device}")
|
logger.info(f"Model device: {model_device}")
|
||||||
|
|
||||||
trans_stream = torch_npu.npu.Stream()
|
trans_stream = torch_npu.npu.Stream()
|
||||||
@@ -84,14 +79,11 @@ class P2PLoad:
|
|||||||
if len(param.shape) == 0:
|
if len(param.shape) == 0:
|
||||||
continue
|
continue
|
||||||
receiver_pg.recv([param], 1, 0).wait()
|
receiver_pg.recv([param], 1, 0).wait()
|
||||||
torch.distributed.barrier(group=receiver_pg,
|
torch.distributed.barrier(group=receiver_pg, device_ids=[model_device.index])
|
||||||
device_ids=[model_device.index])
|
|
||||||
|
|
||||||
torch_npu.npu.synchronize(trans_stream)
|
torch_npu.npu.synchronize(trans_stream)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||||
f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
|
||||||
)
|
|
||||||
loaded_model = model
|
loaded_model = model
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to recv model: {}".format(e))
|
logger.error("Failed to recv model: {}".format(e))
|
||||||
@@ -129,9 +121,7 @@ class P2PSend:
|
|||||||
"""
|
"""
|
||||||
model_device = next(model.parameters()).device
|
model_device = next(model.parameters()).device
|
||||||
torch.npu.set_device(model_device)
|
torch.npu.set_device(model_device)
|
||||||
logger.info(
|
logger.info(f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||||
f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
|
||||||
)
|
|
||||||
sender_pg = None
|
sender_pg = None
|
||||||
try:
|
try:
|
||||||
sender_pg = stateless_init_process_group(
|
sender_pg = stateless_init_process_group(
|
||||||
@@ -139,14 +129,10 @@ class P2PSend:
|
|||||||
port=self.listen_port,
|
port=self.listen_port,
|
||||||
rank=1,
|
rank=1,
|
||||||
world_size=2,
|
world_size=2,
|
||||||
group_name='netloader',
|
group_name="netloader",
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
|
||||||
)
|
)
|
||||||
|
logger.info(f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||||
|
logger.info(f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||||
logger.info(f"Model device: {model_device}")
|
logger.info(f"Model device: {model_device}")
|
||||||
|
|
||||||
trans_stream = torch_npu.npu.Stream()
|
trans_stream = torch_npu.npu.Stream()
|
||||||
@@ -155,16 +141,12 @@ class P2PSend:
|
|||||||
if "aclnn_input_scale" in name:
|
if "aclnn_input_scale" in name:
|
||||||
continue
|
continue
|
||||||
if name in int8_params:
|
if name in int8_params:
|
||||||
sender_pg.send([int8_params[name].to(model_device)], 0,
|
sender_pg.send([int8_params[name].to(model_device)], 0, 0).wait()
|
||||||
0).wait()
|
|
||||||
else:
|
else:
|
||||||
sender_pg.send([param.contiguous()], 0, 0).wait()
|
sender_pg.send([param.contiguous()], 0, 0).wait()
|
||||||
torch.distributed.barrier(group=sender_pg,
|
torch.distributed.barrier(group=sender_pg, device_ids=[model_device.index])
|
||||||
device_ids=[model_device.index])
|
|
||||||
torch_npu.npu.synchronize(trans_stream)
|
torch_npu.npu.synchronize(trans_stream)
|
||||||
logger.info(
|
logger.info(f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||||
f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
if sender_pg:
|
if sender_pg:
|
||||||
destroy_stateless_process_group(sender_pg)
|
destroy_stateless_process_group(sender_pg)
|
||||||
|
|||||||
@@ -17,16 +17,13 @@
|
|||||||
import gc
|
import gc
|
||||||
import ipaddress
|
import ipaddress
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT,
|
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT, _register_process_group, _unregister_process_group
|
||||||
_register_process_group,
|
|
||||||
_unregister_process_group)
|
|
||||||
from torch.distributed import ProcessGroup, is_hccl_available
|
from torch.distributed import ProcessGroup, is_hccl_available
|
||||||
from torch.distributed.distributed_c10d import (Backend, BackendConfig,
|
from torch.distributed.distributed_c10d import Backend, BackendConfig, PrefixStore, _world
|
||||||
PrefixStore, _world)
|
|
||||||
from torch.distributed.rendezvous import rendezvous
|
from torch.distributed.rendezvous import rendezvous
|
||||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -39,7 +36,7 @@ def stateless_init_process_group(
|
|||||||
rank: int,
|
rank: int,
|
||||||
timeout: timedelta = _DEFAULT_PG_TIMEOUT,
|
timeout: timedelta = _DEFAULT_PG_TIMEOUT,
|
||||||
group_name: str = "",
|
group_name: str = "",
|
||||||
pg_options: Optional[Any] = None,
|
pg_options: Any | None = None,
|
||||||
) -> ProcessGroup:
|
) -> ProcessGroup:
|
||||||
"""
|
"""
|
||||||
Initializes a stateless process group.
|
Initializes a stateless process group.
|
||||||
@@ -57,7 +54,8 @@ def stateless_init_process_group(
|
|||||||
ProcessGroup: The initialized process group.
|
ProcessGroup: The initialized process group.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable.
|
RuntimeError: If world_size is not positive, or if rank is not within
|
||||||
|
[0, world_size - 1], or if HCCL is unavailable.
|
||||||
TypeError: If timeout is not a timedelta type.
|
TypeError: If timeout is not a timedelta type.
|
||||||
ValueError: If group_name already exists.
|
ValueError: If group_name already exists.
|
||||||
"""
|
"""
|
||||||
@@ -67,21 +65,18 @@ def stateless_init_process_group(
|
|||||||
raise RuntimeError("world_size must be positive")
|
raise RuntimeError("world_size must be positive")
|
||||||
# Check if rank is within [0, world_size - 1]
|
# Check if rank is within [0, world_size - 1]
|
||||||
if not (rank >= 0 and rank <= world_size - 1):
|
if not (rank >= 0 and rank <= world_size - 1):
|
||||||
raise RuntimeError(
|
raise RuntimeError("rank should be a number between 0 and ``world_size``-1")
|
||||||
"rank should be a number between 0 and ``world_size``-1")
|
|
||||||
# Check if HCCL is available
|
# Check if HCCL is available
|
||||||
if not is_hccl_available():
|
if not is_hccl_available():
|
||||||
raise RuntimeError("HCCL is not available")
|
raise RuntimeError("HCCL is not available")
|
||||||
# Check if timeout is a timedelta type
|
# Check if timeout is a timedelta type
|
||||||
if not isinstance(timeout, timedelta):
|
if not isinstance(timeout, timedelta):
|
||||||
raise TypeError(
|
raise TypeError(f"Expected timeout argument to be of type datetime.timedelta, got {timeout}")
|
||||||
f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
|
|
||||||
)
|
|
||||||
# Check if group_name already exists
|
# Check if group_name already exists
|
||||||
if group_name in _world.pg_names.values():
|
if group_name in _world.pg_names.values():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The specified group name {group_name} has already been "
|
f"The specified group name {group_name} has already been created, please use a different group name"
|
||||||
"created, please use a different group name")
|
)
|
||||||
|
|
||||||
# Function to check if an IPv6 address is valid
|
# Function to check if an IPv6 address is valid
|
||||||
def is_valid_ipv6_address(address: str) -> bool:
|
def is_valid_ipv6_address(address: str) -> bool:
|
||||||
@@ -101,10 +96,9 @@ def stateless_init_process_group(
|
|||||||
# Get initialization method
|
# Get initialization method
|
||||||
init_method = get_tcp_uri(host, port)
|
init_method = get_tcp_uri(host, port)
|
||||||
# Create Backend object
|
# Create Backend object
|
||||||
backend = Backend('hccl')
|
backend = Backend("hccl")
|
||||||
# Use rendezvous function to get store, rank, and world_size
|
# Use rendezvous function to get store, rank, and world_size
|
||||||
store, rank, world_size = next(
|
store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
|
||||||
|
|
||||||
# Set timeout for store
|
# Set timeout for store
|
||||||
store.set_timeout(timeout)
|
store.set_timeout(timeout)
|
||||||
@@ -125,9 +119,7 @@ def stateless_init_process_group(
|
|||||||
pg._set_default_backend(Backend.backend_type_map[backend])
|
pg._set_default_backend(Backend.backend_type_map[backend])
|
||||||
|
|
||||||
# Check if pg_options is None or not of type ProcessGroupHCCL.Options
|
# Check if pg_options is None or not of type ProcessGroupHCCL.Options
|
||||||
if pg_options is None or not isinstance(
|
if pg_options is None or not isinstance(pg_options, torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
||||||
pg_options,
|
|
||||||
torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
|
||||||
pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
|
pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
|
||||||
# Set attributes for pg_options
|
# Set attributes for pg_options
|
||||||
pg_options.is_high_priority_stream = False
|
pg_options.is_high_priority_stream = False
|
||||||
@@ -135,8 +127,7 @@ def stateless_init_process_group(
|
|||||||
pg_options.global_ranks_in_group = []
|
pg_options.global_ranks_in_group = []
|
||||||
pg_options.group_id = f"{init_method}/{group_name}/"
|
pg_options.group_id = f"{init_method}/{group_name}/"
|
||||||
# Create ProcessGroupHCCL object
|
# Create ProcessGroupHCCL object
|
||||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, pg_options)
|
||||||
pg_options)
|
|
||||||
# Set sequence number for backend_class
|
# Set sequence number for backend_class
|
||||||
backend_class._set_sequence_number_for_group()
|
backend_class._set_sequence_number_for_group()
|
||||||
# Set backend_type
|
# Set backend_type
|
||||||
@@ -176,9 +167,10 @@ def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False):
|
|||||||
_world.pg_group_ranks.pop(pg, None)
|
_world.pg_group_ranks.pop(pg, None)
|
||||||
_world.pg_backend_config.pop(pg, None)
|
_world.pg_backend_config.pop(pg, None)
|
||||||
# Check if pg is in keys of _world.pg_coalesce_state
|
# Check if pg is in keys of _world.pg_coalesce_state
|
||||||
if pg in _world.pg_coalesce_state.keys():
|
if pg in _world.pg_coalesce_state:
|
||||||
logger.warning("Some coalesced collectives haven't been launched when "
|
logger.warning(
|
||||||
"ProcessGroup is destroyed. They will be cleaned.")
|
"Some coalesced collectives haven't been launched when ProcessGroup is destroyed. They will be cleaned."
|
||||||
|
)
|
||||||
del _world.pg_coalesce_state[pg]
|
del _world.pg_coalesce_state[pg]
|
||||||
# Unregister the process group
|
# Unregister the process group
|
||||||
_unregister_process_group(pg.group_name)
|
_unregister_process_group(pg.group_name)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
import threading
|
import threading
|
||||||
from typing import List, Optional, Tuple
|
from contextlib import suppress
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -32,8 +32,7 @@ class ElasticClient:
|
|||||||
Class for handling the client-side logic of Netloader of models.
|
Class for handling the client-side logic of Netloader of models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sources: list[str], device_id: int, model_path: str,
|
def __init__(self, sources: list[str], device_id: int, model_path: str, tp: int, pp: int):
|
||||||
tp: int, pp: int):
|
|
||||||
"""
|
"""
|
||||||
Initializes the ElasticClient instance.
|
Initializes the ElasticClient instance.
|
||||||
|
|
||||||
@@ -50,14 +49,14 @@ class ElasticClient:
|
|||||||
self.tp = tp
|
self.tp = tp
|
||||||
self.pp = pp
|
self.pp = pp
|
||||||
|
|
||||||
self.s: Optional[socket.socket] = None
|
self.s: socket.socket | None = None
|
||||||
self.ack: Optional[Tuple[str, int]] = None
|
self.ack: tuple[str, int] | None = None
|
||||||
self.server_addr: Optional[str] = None
|
self.server_addr: str | None = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: int | None = None
|
||||||
|
|
||||||
for source in self.sources:
|
for source in self.sources:
|
||||||
try:
|
try:
|
||||||
ip, port_str = source.split(':')
|
ip, port_str = source.split(":")
|
||||||
port = int(port_str)
|
port = int(port_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"IP format error: {source}, detail: {e}")
|
logger.info(f"IP format error: {source}, detail: {e}")
|
||||||
@@ -68,13 +67,9 @@ class ElasticClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
logger.info(
|
logger.info(f"Start connection to server: {self.server_addr}:{self.server_port}")
|
||||||
f"Start connection to server: {self.server_addr}:{self.server_port}"
|
|
||||||
)
|
|
||||||
sock.connect((self.server_addr, self.server_port))
|
sock.connect((self.server_addr, self.server_port))
|
||||||
logger.info(
|
logger.info(f"Finish connection to server: {self.server_addr}:{self.server_port}")
|
||||||
f"Finish connection to server: {self.server_addr}:{self.server_port}"
|
|
||||||
)
|
|
||||||
sock.settimeout(60)
|
sock.settimeout(60)
|
||||||
|
|
||||||
self.s = sock
|
self.s = sock
|
||||||
@@ -83,10 +78,8 @@ class ElasticClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Connect to {source} fails, detail: {e}")
|
logger.error(f"Connect to {source} fails, detail: {e}")
|
||||||
if sock is not None:
|
if sock is not None:
|
||||||
try:
|
with suppress(Exception):
|
||||||
sock.close()
|
sock.close()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self.s = None
|
self.s = None
|
||||||
self.ack = None
|
self.ack = None
|
||||||
self.server_addr = None
|
self.server_addr = None
|
||||||
@@ -120,10 +113,8 @@ class ElasticClient:
|
|||||||
"""
|
"""
|
||||||
Destructor method to ensure socket is closed.
|
Destructor method to ensure socket is closed.
|
||||||
"""
|
"""
|
||||||
try:
|
with suppress(Exception):
|
||||||
self.close()
|
self.close()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def send_str(self, data_str: str) -> None:
|
def send_str(self, data_str: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -151,8 +142,7 @@ class ElasticClient:
|
|||||||
data_str = self.s.recv(buffer_size).decode("utf-8")
|
data_str = self.s.recv(buffer_size).decode("utf-8")
|
||||||
return data_str
|
return data_str
|
||||||
|
|
||||||
def register(self, device_id: int, model_path: str, tp: int,
|
def register(self, device_id: int, model_path: str, tp: int, pp: int) -> tuple[str, int]:
|
||||||
pp: int) -> Tuple[str, int]:
|
|
||||||
"""
|
"""
|
||||||
Registers the client with the server.
|
Registers the client with the server.
|
||||||
|
|
||||||
@@ -168,20 +158,13 @@ class ElasticClient:
|
|||||||
free_port = find_free_port()
|
free_port = find_free_port()
|
||||||
data = {
|
data = {
|
||||||
"label": "JOIN",
|
"label": "JOIN",
|
||||||
"content": {
|
"content": {"device_id": device_id, "model_path": model_path, "tp": tp, "pp": pp, "port": free_port},
|
||||||
'device_id': device_id,
|
|
||||||
'model_path': model_path,
|
|
||||||
'tp': tp,
|
|
||||||
'pp': pp,
|
|
||||||
'port': free_port
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.send_str(json.dumps(data))
|
self.send_str(json.dumps(data))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Send data {data} to server fails, detail: {e}")
|
||||||
f"Send data {data} to server fails, detail: {e}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ack_str = self.recv_str()
|
ack_str = self.recv_str()
|
||||||
@@ -191,23 +174,22 @@ class ElasticClient:
|
|||||||
try:
|
try:
|
||||||
ack = json.loads(ack_str)
|
ack = json.loads(ack_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}")
|
||||||
f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Receive ack: {ack}")
|
logger.info(f"Receive ack: {ack}")
|
||||||
|
|
||||||
if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack
|
if (
|
||||||
and ack["content"] is not None and "name" in ack["content"]):
|
"label" in ack
|
||||||
|
and ack["label"] == "JOIN_ACK"
|
||||||
|
and "content" in ack
|
||||||
|
and ack["content"] is not None
|
||||||
|
and "name" in ack["content"]
|
||||||
|
):
|
||||||
return (ack["content"]["name"], free_port)
|
return (ack["content"]["name"], free_port)
|
||||||
elif ("label" in ack and ack["label"] == 'JOIN_NACK'
|
elif "label" in ack and ack["label"] == "JOIN_NACK" and "content" in ack:
|
||||||
and "content" in ack):
|
raise RuntimeError(f"Receive nack from server, reason: {ack['content']}")
|
||||||
raise RuntimeError(
|
|
||||||
f"Receive nack from server, reason: {ack['content']}")
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Receive ack {ack} from server does not contain required fields")
|
||||||
f"Receive ack {ack} from server does not contain required fields"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ElasticServer:
|
class ElasticServer:
|
||||||
@@ -215,9 +197,18 @@ class ElasticServer:
|
|||||||
Class for handling the server-side logic of Netloader of models.
|
Class for handling the server-side logic of Netloader of models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, addr: str, port: int, model, device_id: int,
|
def __init__(
|
||||||
model_path: str, tp: int, pp: int, int8_cache: str,
|
self,
|
||||||
int8_cache_name: Optional[List[str]]):
|
addr: str,
|
||||||
|
port: int,
|
||||||
|
model,
|
||||||
|
device_id: int,
|
||||||
|
model_path: str,
|
||||||
|
tp: int,
|
||||||
|
pp: int,
|
||||||
|
int8_cache: str,
|
||||||
|
int8_cache_name: list[str] | None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the ElasticServer instance.
|
Initializes the ElasticServer instance.
|
||||||
|
|
||||||
@@ -246,30 +237,25 @@ class ElasticServer:
|
|||||||
self.pp = pp
|
self.pp = pp
|
||||||
|
|
||||||
self.original_int8 = {}
|
self.original_int8 = {}
|
||||||
int8_pattern = "|".join(
|
int8_pattern = "|".join(map(re.escape, int8_cache_name)) if int8_cache_name is not None else "(?:)"
|
||||||
map(re.escape,
|
|
||||||
int8_cache_name)) if int8_cache_name is not None else "(?:)"
|
|
||||||
for name, param in self.model.named_parameters():
|
for name, param in self.model.named_parameters():
|
||||||
if param.dtype == torch.int8:
|
if param.dtype == torch.int8:
|
||||||
if int8_cache == 'hbm':
|
if int8_cache == "hbm":
|
||||||
if int8_cache_name is None or (
|
if int8_cache_name is None or (
|
||||||
int8_cache_name is not None
|
int8_cache_name is not None and re.search(int8_pattern, name) is not None
|
||||||
and re.search(int8_pattern, name) is not None):
|
):
|
||||||
try:
|
try:
|
||||||
self.original_int8[name] = param.data.clone(
|
self.original_int8[name] = param.data.clone().detach()
|
||||||
).detach()
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(
|
logger.error(f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}")
|
||||||
f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}"
|
|
||||||
)
|
|
||||||
self.original_int8[name] = param.data.cpu()
|
self.original_int8[name] = param.data.cpu()
|
||||||
|
|
||||||
elif int8_cache == 'dram':
|
elif int8_cache == "dram":
|
||||||
if int8_cache_name is None or (
|
if int8_cache_name is None or (
|
||||||
int8_cache_name is not None
|
int8_cache_name is not None and re.search(int8_pattern, name) is not None
|
||||||
and re.search(int8_pattern, name) is not None):
|
):
|
||||||
self.original_int8[name] = param.data.cpu()
|
self.original_int8[name] = param.data.cpu()
|
||||||
elif int8_cache == 'no':
|
elif int8_cache == "no":
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -277,14 +263,18 @@ class ElasticServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}"
|
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, "
|
||||||
|
f"model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, "
|
||||||
|
f"int8 params {list(self.original_int8)} are saved to {int8_cache}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""
|
"""
|
||||||
Destructor method to ensure socket is closed.
|
Destructor method to ensure socket is closed.
|
||||||
"""
|
"""
|
||||||
self.s.close()
|
if self.s is not None:
|
||||||
|
with suppress(Exception):
|
||||||
|
self.s.close()
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""
|
"""
|
||||||
@@ -343,10 +333,7 @@ class ElasticServer:
|
|||||||
if not all(k in content for k in required_keys):
|
if not all(k in content for k in required_keys):
|
||||||
return False
|
return False
|
||||||
port = content["port"]
|
port = content["port"]
|
||||||
if not (isinstance(port, int) or
|
return isinstance(port, int) or (isinstance(port, str) and port.isdigit())
|
||||||
(isinstance(port, str) and port.isdigit())):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
comm_name = None
|
comm_name = None
|
||||||
if is_valid_data(data):
|
if is_valid_data(data):
|
||||||
@@ -355,36 +342,31 @@ class ElasticServer:
|
|||||||
tp = int(data["content"]["tp"])
|
tp = int(data["content"]["tp"])
|
||||||
pp = int(data["content"]["pp"])
|
pp = int(data["content"]["pp"])
|
||||||
|
|
||||||
if int(self.device_id
|
if (
|
||||||
) == device_id and self.model_path == model_path and int(
|
int(self.device_id) == device_id
|
||||||
self.tp) == tp and int(self.pp) == pp:
|
and self.model_path == model_path
|
||||||
|
and int(self.tp) == tp
|
||||||
|
and int(self.pp) == pp
|
||||||
|
):
|
||||||
comm_name = str(addr[0]) + ":" + str(addr[1])
|
comm_name = str(addr[0]) + ":" + str(addr[1])
|
||||||
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
|
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
server_desc = (int(self.device_id), self.model_path, int(self.tp), int(self.pp))
|
||||||
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
|
client_desc = (device_id, model_path, tp, pp)
|
||||||
)
|
msg = f"Received data {client_desc} does not consist with this server {server_desc}"
|
||||||
|
logger.warning(msg)
|
||||||
ack = {
|
ack = {
|
||||||
"label":
|
"label": "JOIN_NACK",
|
||||||
"JOIN_NACK",
|
"content": msg,
|
||||||
"content":
|
|
||||||
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(f"Received data does not contain required fields: {data}")
|
||||||
f"Received data does not contain required fields: {data}")
|
ack = {"label": "JOIN_NACK", "content": f"Received data does not contain required fields: {data}"}
|
||||||
ack = {
|
|
||||||
"label":
|
|
||||||
"JOIN_NACK",
|
|
||||||
"content":
|
|
||||||
f"Received data does not contain required fields: {data}"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ack_str = json.dumps(ack).encode("utf-8")
|
ack_str = json.dumps(ack).encode("utf-8")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Failed to convert {ack} to JSON format, details: {e}")
|
||||||
f"Failed to convert {ack} to JSON format, details: {e}")
|
|
||||||
conn.close()
|
conn.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -395,14 +377,10 @@ class ElasticServer:
|
|||||||
conn.close()
|
conn.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
if ack["content"] and isinstance(ack["content"],
|
if ack["content"] and isinstance(ack["content"], dict) and "name" in ack["content"]:
|
||||||
dict) and 'name' in ack["content"]:
|
|
||||||
try:
|
try:
|
||||||
p2psend = P2PSend(self.addr, data["content"]["port"],
|
p2psend = P2PSend(self.addr, data["content"]["port"], ack["content"]["name"])
|
||||||
ack["content"]["name"])
|
|
||||||
p2psend.send(self.model, self.original_int8)
|
p2psend.send(self.model, self.original_int8)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"P2PSend Failed to send model to {self.addr}, details: {e}")
|
||||||
f"P2PSend Failed to send model to {self.addr}, details: {e}"
|
|
||||||
)
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
@@ -48,36 +48,27 @@ def elastic_load(
|
|||||||
# Filter sources for the current device
|
# Filter sources for the current device
|
||||||
sources_this_device = []
|
sources_this_device = []
|
||||||
for s in sources:
|
for s in sources:
|
||||||
if isinstance(
|
if isinstance(s, dict) and "device_id" in s and s["device_id"] == device_id and isinstance(s["sources"], list):
|
||||||
s, dict
|
|
||||||
) and "device_id" in s and s["device_id"] == device_id and isinstance(
|
|
||||||
s["sources"], list):
|
|
||||||
sources_this_device += s["sources"]
|
sources_this_device += s["sources"]
|
||||||
if len(sources_this_device) == 0:
|
if len(sources_this_device) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the interaction layer with the ElasticClient
|
# Initialize the interaction layer with the ElasticClient
|
||||||
with ElasticClient(sources_this_device, device_id, model_path, tp,
|
with ElasticClient(sources_this_device, device_id, model_path, tp, pp) as client_interaction_layer:
|
||||||
pp) as client_interaction_layer:
|
|
||||||
if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
|
if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError("Failed to initialize ElasticClient: socket or server_addr is None")
|
||||||
"Failed to initialize ElasticClient: socket or server_addr is None"
|
|
||||||
)
|
|
||||||
ack = client_interaction_layer.ack
|
ack = client_interaction_layer.ack
|
||||||
if ack is None:
|
if ack is None:
|
||||||
raise RuntimeError("ElasticClient.register did not return ack")
|
raise RuntimeError("ElasticClient.register did not return ack")
|
||||||
|
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
elastic_loader = P2PLoad(ack[0],
|
elastic_loader = P2PLoad(ack[0], client_interaction_layer.server_addr, ack[1])
|
||||||
client_interaction_layer.server_addr,
|
|
||||||
ack[1])
|
|
||||||
model_loaded = elastic_loader.load(model=model)
|
model_loaded = elastic_loader.load(model=model)
|
||||||
if model_loaded is None:
|
if model_loaded is None:
|
||||||
logger.error("Failed to load model")
|
logger.error("Failed to load model")
|
||||||
return None
|
return None
|
||||||
logger.info("Finish elastic load (duration: {}s)".format(
|
logger.info("Finish elastic load (duration: {}s)".format(time.perf_counter() - t0))
|
||||||
time.perf_counter() - t0))
|
|
||||||
return model_loaded
|
return model_loaded
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"elastic_load error: {e}")
|
logger.info(f"elastic_load error: {e}")
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import gc
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -27,8 +26,7 @@ from vllm.logger import logger
|
|||||||
from vllm.model_executor.model_loader import register_model_loader
|
from vllm.model_executor.model_loader import register_model_loader
|
||||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||||
from vllm.model_executor.model_loader.utils import (
|
from vllm.model_executor.model_loader.utils import initialize_model, process_weights_after_loading
|
||||||
initialize_model, process_weights_after_loading)
|
|
||||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||||
|
|
||||||
from .interaction.elastic import ElasticServer
|
from .interaction.elastic import ElasticServer
|
||||||
@@ -41,12 +39,13 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
"""
|
"""
|
||||||
A model loader that uses elastic loading for loading weights.
|
A model loader that uses elastic loading for loading weights.
|
||||||
"""
|
"""
|
||||||
source: Optional[List[dict]]
|
|
||||||
model_path: Optional[str]
|
source: list[dict] | None
|
||||||
listen_port: Optional[int]
|
model_path: str | None
|
||||||
|
listen_port: int | None
|
||||||
int8_cache: str
|
int8_cache: str
|
||||||
int8_cache_name: Optional[List[str]]
|
int8_cache_name: list[str] | None
|
||||||
output_prefix: Optional[str]
|
output_prefix: str | None
|
||||||
|
|
||||||
def __init__(self, load_config: LoadConfig):
|
def __init__(self, load_config: LoadConfig):
|
||||||
"""
|
"""
|
||||||
@@ -63,18 +62,15 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
extra = load_config.model_loader_extra_config
|
extra = load_config.model_loader_extra_config
|
||||||
if extra and "CONFIG_FILE" in extra:
|
if extra and "CONFIG_FILE" in extra:
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ...")
|
||||||
f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..."
|
with open(extra["CONFIG_FILE"]) as f:
|
||||||
)
|
|
||||||
with open(extra["CONFIG_FILE"], 'r') as f:
|
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error("CONFIG_FILE not found")
|
logger.error("CONFIG_FILE not found")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error("CONFIG_FILE is not a valid JSON file")
|
logger.error("CONFIG_FILE is not a valid JSON file")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Unexpected error while reading CONFIG_FILE: {e}")
|
||||||
f"Unexpected error while reading CONFIG_FILE: {e}")
|
|
||||||
|
|
||||||
if config is None and extra:
|
if config is None and extra:
|
||||||
logger.info("Reading configs in model_loader_extra_config ...")
|
logger.info("Reading configs in model_loader_extra_config ...")
|
||||||
@@ -82,19 +78,30 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
config = config or {}
|
config = config or {}
|
||||||
|
|
||||||
for key, attr, checker, caster, default in [
|
for key, attr, checker, caster, default in [
|
||||||
("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v,
|
("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, None),
|
||||||
None),
|
("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, None),
|
||||||
("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v,
|
(
|
||||||
None),
|
"LISTEN_PORT",
|
||||||
("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or
|
"listen_port",
|
||||||
(isinstance(v, str) and v.isdigit()), lambda v: int(v), None),
|
lambda v: isinstance(v, int) or (isinstance(v, str) and v.isdigit()),
|
||||||
("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v.
|
lambda v: int(v),
|
||||||
lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'),
|
None,
|
||||||
("INT8_CACHE_NAME", "int8_cache_name",
|
),
|
||||||
lambda v: isinstance(v, list), lambda v: v, None),
|
(
|
||||||
("OUTPUT_PREFIX", "output_prefix",
|
"INT8_CACHE",
|
||||||
lambda v: isinstance(v, str) and is_valid_path_prefix(v),
|
"int8_cache",
|
||||||
lambda v: v, None),
|
lambda v: isinstance(v, str) and v.lower() in ["hbm", "dram", "no"],
|
||||||
|
lambda v: v.lower(),
|
||||||
|
"no",
|
||||||
|
),
|
||||||
|
("INT8_CACHE_NAME", "int8_cache_name", lambda v: isinstance(v, list), lambda v: v, None),
|
||||||
|
(
|
||||||
|
"OUTPUT_PREFIX",
|
||||||
|
"output_prefix",
|
||||||
|
lambda v: isinstance(v, str) and is_valid_path_prefix(v),
|
||||||
|
lambda v: v,
|
||||||
|
None,
|
||||||
|
),
|
||||||
]:
|
]:
|
||||||
v = config.get(key, default)
|
v = config.get(key, default)
|
||||||
if not checker(v):
|
if not checker(v):
|
||||||
@@ -116,8 +123,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
self.output_prefix,
|
self.output_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_model(self, vllm_config: VllmConfig,
|
def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module:
|
||||||
model_config: ModelConfig) -> nn.Module:
|
|
||||||
"""
|
"""
|
||||||
Loads the model using the specified configuration.
|
Loads the model using the specified configuration.
|
||||||
|
|
||||||
@@ -140,15 +146,18 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
|
|
||||||
device_id = torch.distributed.get_rank()
|
device_id = torch.distributed.get_rank()
|
||||||
|
|
||||||
if (self.source is None or not isinstance(self.source, list)
|
if (
|
||||||
or device_id not in [
|
self.source is None
|
||||||
one_device["device_id"] for one_device in self.source if
|
or not isinstance(self.source, list)
|
||||||
isinstance(one_device, dict) and "device_id" in one_device
|
or device_id
|
||||||
]):
|
not in [
|
||||||
logger.warning(
|
one_device["device_id"]
|
||||||
"Did not get valid source info, use DefaultModelLoader")
|
for one_device in self.source
|
||||||
model, need_process_weights_after_loading = self.revert_to_default(
|
if isinstance(one_device, dict) and "device_id" in one_device
|
||||||
model_config, vllm_config, device_config)
|
]
|
||||||
|
):
|
||||||
|
logger.warning("Did not get valid source info, use DefaultModelLoader")
|
||||||
|
model, need_process_weights_after_loading = self.revert_to_default(model_config, vllm_config, device_config)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
target_device = torch.device(device_config.device)
|
target_device = torch.device(device_config.device)
|
||||||
@@ -158,8 +167,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
|
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with target_device:
|
with target_device:
|
||||||
model = initialize_model(vllm_config=vllm_config,
|
model = initialize_model(vllm_config=vllm_config, model_config=model_config)
|
||||||
model_config=model_config)
|
|
||||||
|
|
||||||
start_elastic_load = time.perf_counter()
|
start_elastic_load = time.perf_counter()
|
||||||
model = elastic_load(
|
model = elastic_load(
|
||||||
@@ -171,43 +179,39 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
pp=parallel_config.pipeline_parallel_size,
|
pp=parallel_config.pipeline_parallel_size,
|
||||||
)
|
)
|
||||||
end_elastic_load = time.perf_counter()
|
end_elastic_load = time.perf_counter()
|
||||||
logger.info(
|
logger.info(f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}")
|
||||||
f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}"
|
|
||||||
)
|
|
||||||
need_process_weights_after_loading = True
|
need_process_weights_after_loading = True
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
logger.warning(
|
logger.warning("Netloader elastic loading fails, use load format DefaultModelLoader")
|
||||||
"Netloader elastic loading fails, use load format DefaultModelLoader"
|
|
||||||
)
|
|
||||||
|
|
||||||
vllm_config = vllm_config_backup
|
vllm_config = vllm_config_backup
|
||||||
model_config = model_config_backup
|
model_config = model_config_backup
|
||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if device_config.device_type == 'npu':
|
if device_config.device_type == "npu":
|
||||||
logger.info("Empty NPU cache")
|
logger.info("Empty NPU cache")
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
elif device_config.device_type == 'cuda':
|
elif device_config.device_type == "cuda":
|
||||||
logger.info("Empty CUDA cache")
|
logger.info("Empty CUDA cache")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model, need_process_weights_after_loading = self.revert_to_default(
|
model, need_process_weights_after_loading = self.revert_to_default(
|
||||||
model_config, vllm_config, device_config)
|
model_config, vllm_config, device_config
|
||||||
|
)
|
||||||
|
|
||||||
start_elastic_server = time.perf_counter()
|
start_elastic_server = time.perf_counter()
|
||||||
# start elastic server
|
# start elastic server
|
||||||
if model is not None and (
|
if model is not None and (
|
||||||
(self.listen_port and self.listen_port in range(1024, 65535)) or
|
(self.listen_port and self.listen_port in range(1024, 65535)) or (self.listen_port is None)
|
||||||
(self.listen_port is None)):
|
):
|
||||||
|
|
||||||
from vllm.utils.network_utils import get_ip
|
from vllm.utils.network_utils import get_ip
|
||||||
|
|
||||||
driver_ip = get_ip()
|
driver_ip = get_ip()
|
||||||
|
|
||||||
if driver_ip == '0.0.0.0':
|
if driver_ip == "0.0.0.0":
|
||||||
logger.error(
|
logger.error("Driver IP is not set, skip to start Netloader server")
|
||||||
"Driver IP is not set, skip to start Netloader server")
|
|
||||||
else:
|
else:
|
||||||
if self.listen_port is None:
|
if self.listen_port is None:
|
||||||
self.listen_port = find_free_port()
|
self.listen_port = find_free_port()
|
||||||
@@ -220,21 +224,14 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
|
|
||||||
if self.output_prefix is not None:
|
if self.output_prefix is not None:
|
||||||
try:
|
try:
|
||||||
with open(self.output_prefix + str(device_id) + '.txt',
|
with open(self.output_prefix + str(device_id) + ".txt", "w") as file:
|
||||||
'w') as file:
|
|
||||||
file.write(f"{driver_ip}:{self.listen_port}")
|
file.write(f"{driver_ip}:{self.listen_port}")
|
||||||
logger.info(
|
logger.info(f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}")
|
||||||
f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}"
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error(
|
logger.error(f"File path {self.output_prefix + str(device_id)} does not exist.")
|
||||||
f"File path {self.output_prefix + str(device_id)} does not exist."
|
|
||||||
)
|
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
logger.error(
|
logger.error(f"No permission to write to file {self.output_prefix + str(device_id)}.")
|
||||||
f"No permission to write to file {self.output_prefix + str(device_id)}."
|
except OSError as e:
|
||||||
)
|
|
||||||
except IOError as e:
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}"
|
f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}"
|
||||||
)
|
)
|
||||||
@@ -242,31 +239,30 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
logger.error(f"Unknown error: {e}")
|
logger.error(f"Unknown error: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert isinstance(
|
assert isinstance(self.listen_port, int), f"listen port should be int but get {self.listen_port}"
|
||||||
self.listen_port, int
|
|
||||||
), f"listen port should be int but get {self.listen_port}"
|
|
||||||
|
|
||||||
elastic_server = ElasticServer(
|
elastic_server = ElasticServer(
|
||||||
driver_ip, self.listen_port, model, device_id,
|
driver_ip,
|
||||||
self.model_path, parallel_config.tensor_parallel_size,
|
self.listen_port,
|
||||||
|
model,
|
||||||
|
device_id,
|
||||||
|
self.model_path,
|
||||||
|
parallel_config.tensor_parallel_size,
|
||||||
parallel_config.pipeline_parallel_size,
|
parallel_config.pipeline_parallel_size,
|
||||||
self.int8_cache, self.int8_cache_name)
|
self.int8_cache,
|
||||||
|
self.int8_cache_name,
|
||||||
|
)
|
||||||
elastic_server.start()
|
elastic_server.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Failed to start Netloader server for rank: {device_id}, details: {e}")
|
||||||
f"Failed to start Netloader server for rank: {device_id}, details: {e}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info("Skip to start Netloader server")
|
logger.info("Skip to start Netloader server")
|
||||||
|
|
||||||
end_elastic_server = time.perf_counter()
|
end_elastic_server = time.perf_counter()
|
||||||
logger.info(
|
logger.info(f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}")
|
||||||
f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if need_process_weights_after_loading:
|
if need_process_weights_after_loading:
|
||||||
process_weights_after_loading(model, model_config,
|
process_weights_after_loading(model, model_config, torch.device(device_config.device))
|
||||||
torch.device(device_config.device))
|
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
logger.error("NetLoader elastic loads model fails")
|
logger.error("NetLoader elastic loads model fails")
|
||||||
@@ -274,8 +270,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
def revert_to_default(self, model_config, vllm_config,
|
def revert_to_default(self, model_config, vllm_config, device_config) -> tuple[nn.Module, bool]:
|
||||||
device_config) -> Tuple[nn.Module, bool]:
|
|
||||||
"""
|
"""
|
||||||
Reverts to the default model loading logic when elastic loading fails or is not applicable.
|
Reverts to the default model loading logic when elastic loading fails or is not applicable.
|
||||||
|
|
||||||
@@ -300,19 +295,15 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
default_model_loader = DefaultModelLoader(self.load_config)
|
default_model_loader = DefaultModelLoader(self.load_config)
|
||||||
|
|
||||||
if model_config.quantization is None:
|
if model_config.quantization is None:
|
||||||
model = default_model_loader.load_model(vllm_config=vllm_config,
|
model = default_model_loader.load_model(vllm_config=vllm_config, model_config=model_config)
|
||||||
model_config=model_config)
|
|
||||||
need_process_weights_after_loading = False
|
need_process_weights_after_loading = False
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning("Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading ")
|
||||||
"Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading "
|
|
||||||
)
|
|
||||||
need_process_weights_after_loading = True
|
need_process_weights_after_loading = True
|
||||||
target_device = torch.device(device_config.device)
|
target_device = torch.device(device_config.device)
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with target_device:
|
with target_device:
|
||||||
model = initialize_model(vllm_config=vllm_config,
|
model = initialize_model(vllm_config=vllm_config, model_config=model_config)
|
||||||
model_config=model_config)
|
|
||||||
default_model_loader.load_weights(model, model_config)
|
default_model_loader.load_weights(model, model_config)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
@@ -321,6 +312,5 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
|||||||
def download_model(self, model_config: ModelConfig) -> None:
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def load_weights(self, model: nn.Module,
|
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||||
model_config: ModelConfig) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def find_free_port():
|
|||||||
- A free port number.
|
- A free port number.
|
||||||
"""
|
"""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.bind(('', 0))
|
s.bind(("", 0))
|
||||||
return s.getsockname()[1]
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
@@ -47,20 +47,17 @@ def is_valid_path_prefix(path_prefix):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if re.search(r'[<>:"|?*]', path_prefix):
|
if re.search(r'[<>:"|?*]', path_prefix):
|
||||||
logger.warning(
|
logger.warning(f"The path prefix {path_prefix} contains illegal characters.")
|
||||||
f'The path prefix {path_prefix} contains illegal characters.')
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if path_prefix.startswith('/') or path_prefix.startswith('\\'):
|
if path_prefix.startswith("/") or path_prefix.startswith("\\"):
|
||||||
if not os.path.exists(os.path.dirname(path_prefix)):
|
if not os.path.exists(os.path.dirname(path_prefix)):
|
||||||
logger.warning(
|
logger.warning(f"The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.")
|
||||||
f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.'
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
|
if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.'
|
f"The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -23,9 +23,8 @@ import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
|||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.utils import vllm_version_is
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv(
|
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
|
||||||
"EXPERT_MAP_RECORD", "false") == "true":
|
|
||||||
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
|
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
|
||||||
|
|
||||||
if envs.VLLM_ASCEND_BALANCE_SCHEDULING and vllm_version_is('0.14.0'):
|
if envs.VLLM_ASCEND_BALANCE_SCHEDULING and vllm_version_is("0.14.0"):
|
||||||
import vllm_ascend.patch.platform.patch_balance_schedule # noqa
|
import vllm_ascend.patch.platform.patch_balance_schedule # noqa
|
||||||
|
|||||||
@@ -7,17 +7,14 @@ import torch.distributed as dist
|
|||||||
import vllm
|
import vllm
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||||
KVConnectorMetadata
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.transformers_utils.config import \
|
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||||
maybe_register_config_serialize_by_value
|
|
||||||
from vllm.utils.system_utils import decorate_logs, set_process_title
|
from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||||
create_request_queue)
|
|
||||||
from vllm.v1.core.sched.scheduler import Scheduler
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
|
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
|
||||||
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
|
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
|
||||||
@@ -30,7 +27,6 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class BalanceScheduler(Scheduler):
|
class BalanceScheduler(Scheduler):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vllm_config,
|
vllm_config,
|
||||||
@@ -41,9 +37,15 @@ class BalanceScheduler(Scheduler):
|
|||||||
include_finished_set: bool = False,
|
include_finished_set: bool = False,
|
||||||
log_stats: bool = False,
|
log_stats: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(vllm_config, kv_cache_config,
|
super().__init__(
|
||||||
structured_output_manager, block_size, mm_registry,
|
vllm_config,
|
||||||
include_finished_set, log_stats)
|
kv_cache_config,
|
||||||
|
structured_output_manager,
|
||||||
|
block_size,
|
||||||
|
mm_registry,
|
||||||
|
include_finished_set,
|
||||||
|
log_stats,
|
||||||
|
)
|
||||||
# Balance scheduling.
|
# Balance scheduling.
|
||||||
self.balance_queue = [
|
self.balance_queue = [
|
||||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||||
@@ -51,9 +53,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def balance_gather(self, dp_group):
|
def balance_gather(self, dp_group):
|
||||||
running_tensor = torch.tensor([len(self.running)],
|
running_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")
|
||||||
dtype=torch.int,
|
|
||||||
device="cpu")
|
|
||||||
dist.all_gather(self.balance_queue, running_tensor, group=dp_group)
|
dist.all_gather(self.balance_queue, running_tensor, group=dp_group)
|
||||||
|
|
||||||
def schedule(self) -> SchedulerOutput:
|
def schedule(self) -> SchedulerOutput:
|
||||||
@@ -89,33 +89,32 @@ class BalanceScheduler(Scheduler):
|
|||||||
while req_index < len(self.running) and token_budget > 0:
|
while req_index < len(self.running) and token_budget > 0:
|
||||||
request = self.running[req_index]
|
request = self.running[req_index]
|
||||||
|
|
||||||
if (request.num_output_placeholders > 0
|
if (
|
||||||
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
|
request.num_output_placeholders > 0
|
||||||
# Since output placeholders are also included in the computed tokens
|
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
|
||||||
# count, we subtract (num_output_placeholders - 1) to remove any draft
|
# Since output placeholders are also included in the computed tokens
|
||||||
# tokens, so that we can be sure no further steps are needed even if
|
# count, we subtract (num_output_placeholders - 1) to remove any draft
|
||||||
# they are all rejected.
|
# tokens, so that we can be sure no further steps are needed even if
|
||||||
and request.num_computed_tokens + 2 -
|
# they are all rejected.
|
||||||
request.num_output_placeholders
|
and request.num_computed_tokens + 2 - request.num_output_placeholders
|
||||||
>= request.num_prompt_tokens + request.max_tokens):
|
>= request.num_prompt_tokens + request.max_tokens
|
||||||
|
):
|
||||||
# Async scheduling: Avoid scheduling an extra step when we are sure that
|
# Async scheduling: Avoid scheduling an extra step when we are sure that
|
||||||
# the previous step has reached request.max_tokens. We don't schedule
|
# the previous step has reached request.max_tokens. We don't schedule
|
||||||
# partial draft tokens since this prevents uniform decode optimizations.
|
# partial draft tokens since this prevents uniform decode optimizations.
|
||||||
req_index += 1
|
req_index += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_new_tokens = (request.num_tokens_with_spec +
|
num_new_tokens = (
|
||||||
request.num_output_placeholders -
|
request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens
|
||||||
request.num_computed_tokens)
|
)
|
||||||
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
|
||||||
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
|
||||||
num_new_tokens = min(num_new_tokens, token_budget)
|
num_new_tokens = min(num_new_tokens, token_budget)
|
||||||
|
|
||||||
# Make sure the input position does not exceed the max model len.
|
# Make sure the input position does not exceed the max model len.
|
||||||
# This is necessary when using spec decoding.
|
# This is necessary when using spec decoding.
|
||||||
num_new_tokens = min(
|
num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
|
||||||
num_new_tokens,
|
|
||||||
self.max_model_len - 1 - request.num_computed_tokens)
|
|
||||||
|
|
||||||
# Schedule encoder inputs.
|
# Schedule encoder inputs.
|
||||||
encoder_inputs_to_schedule = None
|
encoder_inputs_to_schedule = None
|
||||||
@@ -174,20 +173,17 @@ class BalanceScheduler(Scheduler):
|
|||||||
self.running.remove(preempted_req)
|
self.running.remove(preempted_req)
|
||||||
if preempted_req in scheduled_running_reqs:
|
if preempted_req in scheduled_running_reqs:
|
||||||
scheduled_running_reqs.remove(preempted_req)
|
scheduled_running_reqs.remove(preempted_req)
|
||||||
token_budget += num_scheduled_tokens[
|
token_budget += num_scheduled_tokens[preempted_req.request_id]
|
||||||
preempted_req.request_id]
|
|
||||||
req_to_new_blocks.pop(preempted_req.request_id)
|
req_to_new_blocks.pop(preempted_req.request_id)
|
||||||
num_scheduled_tokens.pop(preempted_req.request_id)
|
num_scheduled_tokens.pop(preempted_req.request_id)
|
||||||
scheduled_spec_decode_tokens.pop(
|
scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
|
||||||
preempted_req.request_id, None)
|
preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
|
||||||
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
|
|
||||||
preempted_req.request_id, None)
|
|
||||||
if preempted_encoder_inputs:
|
if preempted_encoder_inputs:
|
||||||
# Restore encoder compute budget if the preempted
|
# Restore encoder compute budget if the preempted
|
||||||
# request had encoder inputs scheduled in this step.
|
# request had encoder inputs scheduled in this step.
|
||||||
num_embeds_to_restore = sum(
|
num_embeds_to_restore = sum(
|
||||||
preempted_req.get_num_encoder_embeds(i)
|
preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs
|
||||||
for i in preempted_encoder_inputs)
|
)
|
||||||
encoder_compute_budget += num_embeds_to_restore
|
encoder_compute_budget += num_embeds_to_restore
|
||||||
req_index -= 1
|
req_index -= 1
|
||||||
else:
|
else:
|
||||||
@@ -212,23 +208,20 @@ class BalanceScheduler(Scheduler):
|
|||||||
|
|
||||||
# Speculative decode related.
|
# Speculative decode related.
|
||||||
if request.spec_token_ids:
|
if request.spec_token_ids:
|
||||||
num_scheduled_spec_tokens = (num_new_tokens +
|
num_scheduled_spec_tokens = (
|
||||||
request.num_computed_tokens -
|
num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
|
||||||
request.num_tokens -
|
)
|
||||||
request.num_output_placeholders)
|
|
||||||
if num_scheduled_spec_tokens > 0:
|
if num_scheduled_spec_tokens > 0:
|
||||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||||
scheduled_spec_decode_tokens[request.request_id] = (
|
scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
|
||||||
request.spec_token_ids)
|
|
||||||
# New spec tokens will be set in `update_draft_token_ids` before the
|
# New spec tokens will be set in `update_draft_token_ids` before the
|
||||||
# next step when applicable.
|
# next step when applicable.
|
||||||
request.spec_token_ids = []
|
request.spec_token_ids = []
|
||||||
|
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
if encoder_inputs_to_schedule:
|
if encoder_inputs_to_schedule:
|
||||||
scheduled_encoder_inputs[request.request_id] = (
|
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||||
encoder_inputs_to_schedule)
|
|
||||||
# Allocate the encoder cache.
|
# Allocate the encoder cache.
|
||||||
for i in encoder_inputs_to_schedule:
|
for i in encoder_inputs_to_schedule:
|
||||||
self.encoder_cache_manager.allocate(request, i)
|
self.encoder_cache_manager.allocate(request, i)
|
||||||
@@ -243,8 +236,10 @@ class BalanceScheduler(Scheduler):
|
|||||||
scheduled_loras: set[int] = set()
|
scheduled_loras: set[int] = set()
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
scheduled_loras = set(
|
scheduled_loras = set(
|
||||||
req.lora_request.lora_int_id for req in scheduled_running_reqs
|
req.lora_request.lora_int_id
|
||||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
for req in scheduled_running_reqs
|
||||||
|
if req.lora_request and req.lora_request.lora_int_id > 0
|
||||||
|
)
|
||||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||||
|
|
||||||
# Use a temporary RequestQueue to collect requests that need to be
|
# Use a temporary RequestQueue to collect requests that need to be
|
||||||
@@ -257,9 +252,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
if len(self.running) == self.max_num_running_reqs:
|
if len(self.running) == self.max_num_running_reqs:
|
||||||
break
|
break
|
||||||
|
|
||||||
balance_flag = (max(
|
balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
|
||||||
t.item()
|
|
||||||
for t in self.balance_queue) == self.max_num_running_reqs)
|
|
||||||
if balance_flag:
|
if balance_flag:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -292,9 +285,14 @@ class BalanceScheduler(Scheduler):
|
|||||||
|
|
||||||
# Check that adding the request still respects the max_loras
|
# Check that adding the request still respects the max_loras
|
||||||
# constraint.
|
# constraint.
|
||||||
if (self.lora_config and request.lora_request and
|
if (
|
||||||
(len(scheduled_loras) == self.lora_config.max_loras and
|
self.lora_config
|
||||||
request.lora_request.lora_int_id not in scheduled_loras)):
|
and request.lora_request
|
||||||
|
and (
|
||||||
|
len(scheduled_loras) == self.lora_config.max_loras
|
||||||
|
and request.lora_request.lora_int_id not in scheduled_loras
|
||||||
|
)
|
||||||
|
):
|
||||||
# Scheduling would exceed max_loras, skip.
|
# Scheduling would exceed max_loras, skip.
|
||||||
self.waiting.pop_request()
|
self.waiting.pop_request()
|
||||||
skipped_waiting_requests.prepend_request(request)
|
skipped_waiting_requests.prepend_request(request)
|
||||||
@@ -306,14 +304,15 @@ class BalanceScheduler(Scheduler):
|
|||||||
# Get already-cached tokens.
|
# Get already-cached tokens.
|
||||||
if request.num_computed_tokens == 0:
|
if request.num_computed_tokens == 0:
|
||||||
# Get locally-cached tokens.
|
# Get locally-cached tokens.
|
||||||
new_computed_blocks, num_new_local_computed_tokens = (
|
new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks(
|
||||||
self.kv_cache_manager.get_computed_blocks(request))
|
request
|
||||||
|
)
|
||||||
|
|
||||||
# Get externally-cached tokens if using a KVConnector.
|
# Get externally-cached tokens if using a KVConnector.
|
||||||
if self.connector is not None:
|
if self.connector is not None:
|
||||||
ext_tokens, load_kv_async = (
|
ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
|
||||||
self.connector.get_num_new_matched_tokens(
|
request, num_new_local_computed_tokens
|
||||||
request, num_new_local_computed_tokens))
|
)
|
||||||
|
|
||||||
if ext_tokens is None:
|
if ext_tokens is None:
|
||||||
# The request cannot be scheduled because
|
# The request cannot be scheduled because
|
||||||
@@ -327,8 +326,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
num_external_computed_tokens = ext_tokens
|
num_external_computed_tokens = ext_tokens
|
||||||
|
|
||||||
# Total computed tokens (local + external).
|
# Total computed tokens (local + external).
|
||||||
num_computed_tokens = (num_new_local_computed_tokens +
|
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
|
||||||
num_external_computed_tokens)
|
|
||||||
else:
|
else:
|
||||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||||
# after async KV recvs are completed.
|
# after async KV recvs are completed.
|
||||||
@@ -356,8 +354,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
|
|
||||||
# chunked prefill has to be enabled explicitly to allow
|
# chunked prefill has to be enabled explicitly to allow
|
||||||
# pooling requests to be chunked
|
# pooling requests to be chunked
|
||||||
if (not self.scheduler_config.enable_chunked_prefill
|
if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget:
|
||||||
and num_new_tokens > token_budget):
|
|
||||||
# If chunked_prefill is disabled,
|
# If chunked_prefill is disabled,
|
||||||
# we can stop the scheduling here.
|
# we can stop the scheduling here.
|
||||||
break
|
break
|
||||||
@@ -388,9 +385,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
# extra block gets allocated which
|
# extra block gets allocated which
|
||||||
# creates a mismatch between the number
|
# creates a mismatch between the number
|
||||||
# of local and remote blocks.
|
# of local and remote blocks.
|
||||||
effective_lookahead_tokens = (0 if request.num_computed_tokens
|
effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
|
||||||
== 0 else
|
|
||||||
self.num_lookahead_tokens)
|
|
||||||
|
|
||||||
# Determine if we need to allocate cross-attention blocks.
|
# Determine if we need to allocate cross-attention blocks.
|
||||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||||
@@ -398,8 +393,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
# always padded to the maximum length. If we support other
|
# always padded to the maximum length. If we support other
|
||||||
# encoder-decoder models, this will need to be updated if we
|
# encoder-decoder models, this will need to be updated if we
|
||||||
# want to only allocate what is needed.
|
# want to only allocate what is needed.
|
||||||
num_encoder_tokens = (
|
num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens
|
||||||
self.scheduler_config.max_num_encoder_input_tokens)
|
|
||||||
else:
|
else:
|
||||||
num_encoder_tokens = 0
|
num_encoder_tokens = 0
|
||||||
|
|
||||||
@@ -442,20 +436,17 @@ class BalanceScheduler(Scheduler):
|
|||||||
|
|
||||||
self.running.append(request)
|
self.running.append(request)
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
|
||||||
scheduled_timestamp)
|
|
||||||
if request.status == RequestStatus.WAITING:
|
if request.status == RequestStatus.WAITING:
|
||||||
scheduled_new_reqs.append(request)
|
scheduled_new_reqs.append(request)
|
||||||
elif request.status == RequestStatus.PREEMPTED:
|
elif request.status == RequestStatus.PREEMPTED:
|
||||||
scheduled_resumed_reqs.append(request)
|
scheduled_resumed_reqs.append(request)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||||
f"Invalid request status: {request.status}")
|
|
||||||
|
|
||||||
if self.lora_config and request.lora_request:
|
if self.lora_config and request.lora_request:
|
||||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||||
req_to_new_blocks[request.request_id] = (
|
req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
|
||||||
self.kv_cache_manager.get_blocks(request.request_id))
|
|
||||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
request.status = RequestStatus.RUNNING
|
request.status = RequestStatus.RUNNING
|
||||||
@@ -465,8 +456,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
request.num_cached_tokens = num_computed_tokens
|
request.num_cached_tokens = num_computed_tokens
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
if encoder_inputs_to_schedule:
|
if encoder_inputs_to_schedule:
|
||||||
scheduled_encoder_inputs[request.request_id] = (
|
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
|
||||||
encoder_inputs_to_schedule)
|
|
||||||
# Allocate the encoder cache.
|
# Allocate the encoder cache.
|
||||||
for i in encoder_inputs_to_schedule:
|
for i in encoder_inputs_to_schedule:
|
||||||
self.encoder_cache_manager.allocate(request, i)
|
self.encoder_cache_manager.allocate(request, i)
|
||||||
@@ -476,8 +466,7 @@ class BalanceScheduler(Scheduler):
|
|||||||
for i in external_load_encoder_input:
|
for i in external_load_encoder_input:
|
||||||
self.encoder_cache_manager.allocate(request, i)
|
self.encoder_cache_manager.allocate(request, i)
|
||||||
if self.ec_connector is not None:
|
if self.ec_connector is not None:
|
||||||
self.ec_connector.update_state_after_alloc(
|
self.ec_connector.update_state_after_alloc(request, i)
|
||||||
request, i)
|
|
||||||
# Put back any skipped requests at the head of the waiting queue
|
# Put back any skipped requests at the head of the waiting queue
|
||||||
if skipped_waiting_requests:
|
if skipped_waiting_requests:
|
||||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||||
@@ -491,20 +480,15 @@ class BalanceScheduler(Scheduler):
|
|||||||
# Since some requests in the RUNNING queue may not be scheduled in
|
# Since some requests in the RUNNING queue may not be scheduled in
|
||||||
# this step, the total number of scheduled requests can be smaller than
|
# this step, the total number of scheduled requests can be smaller than
|
||||||
# len(self.running).
|
# len(self.running).
|
||||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)
|
||||||
scheduled_running_reqs) <= len(self.running)
|
|
||||||
|
|
||||||
# Get the longest common prefix among all requests in the running queue.
|
# Get the longest common prefix among all requests in the running queue.
|
||||||
# This can be potentially used for cascade attention.
|
# This can be potentially used for cascade attention.
|
||||||
num_common_prefix_blocks = [0] * len(
|
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
|
||||||
self.kv_cache_config.kv_cache_groups)
|
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
|
||||||
with record_function_or_nullcontext(
|
|
||||||
"schedule: get_num_common_prefix_blocks"):
|
|
||||||
if self.running:
|
if self.running:
|
||||||
any_request = self.running[0]
|
any_request = self.running[0]
|
||||||
num_common_prefix_blocks = (
|
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
|
||||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
|
||||||
any_request.request_id))
|
|
||||||
|
|
||||||
# Construct the scheduler output.
|
# Construct the scheduler output.
|
||||||
if self.use_v2_model_runner:
|
if self.use_v2_model_runner:
|
||||||
@@ -515,17 +499,16 @@ class BalanceScheduler(Scheduler):
|
|||||||
req,
|
req,
|
||||||
req_to_new_blocks[req.request_id].get_block_ids(),
|
req_to_new_blocks[req.request_id].get_block_ids(),
|
||||||
req._all_token_ids,
|
req._all_token_ids,
|
||||||
) for req in scheduled_new_reqs
|
)
|
||||||
|
for req in scheduled_new_reqs
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
new_reqs_data = [
|
new_reqs_data = [
|
||||||
NewRequestData.from_request(
|
NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
|
||||||
for req in scheduled_new_reqs
|
for req in scheduled_new_reqs
|
||||||
]
|
]
|
||||||
|
|
||||||
with record_function_or_nullcontext(
|
with record_function_or_nullcontext("schedule: make_cached_request_data"):
|
||||||
"schedule: make_cached_request_data"):
|
|
||||||
cached_reqs_data = self._make_cached_request_data(
|
cached_reqs_data = self._make_cached_request_data(
|
||||||
scheduled_running_reqs,
|
scheduled_running_reqs,
|
||||||
scheduled_resumed_reqs,
|
scheduled_resumed_reqs,
|
||||||
@@ -546,15 +529,13 @@ class BalanceScheduler(Scheduler):
|
|||||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||||
preempted_req_ids={req.request_id
|
preempted_req_ids={req.request_id for req in preempted_reqs},
|
||||||
for req in preempted_reqs},
|
|
||||||
# finished_req_ids is an existing state in the scheduler,
|
# finished_req_ids is an existing state in the scheduler,
|
||||||
# instead of being newly scheduled in this step.
|
# instead of being newly scheduled in this step.
|
||||||
# It contains the request IDs that are finished in between
|
# It contains the request IDs that are finished in between
|
||||||
# the previous and the current steps.
|
# the previous and the current steps.
|
||||||
finished_req_ids=self.finished_req_ids,
|
finished_req_ids=self.finished_req_ids,
|
||||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||||
get_freed_mm_hashes(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||||
@@ -562,14 +543,12 @@ class BalanceScheduler(Scheduler):
|
|||||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||||
# 3. Clear the internal states of the connector
|
# 3. Clear the internal states of the connector
|
||||||
if self.connector is not None:
|
if self.connector is not None:
|
||||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(
|
meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output)
|
||||||
scheduler_output)
|
|
||||||
scheduler_output.kv_connector_metadata = meta
|
scheduler_output.kv_connector_metadata = meta
|
||||||
|
|
||||||
# Build the connector meta for ECConnector
|
# Build the connector meta for ECConnector
|
||||||
if self.ec_connector is not None:
|
if self.ec_connector is not None:
|
||||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
|
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output)
|
||||||
scheduler_output)
|
|
||||||
scheduler_output.ec_connector_metadata = ec_meta
|
scheduler_output.ec_connector_metadata = ec_meta
|
||||||
|
|
||||||
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
||||||
@@ -578,7 +557,6 @@ class BalanceScheduler(Scheduler):
|
|||||||
|
|
||||||
|
|
||||||
class BalanceDPEngineCoreProc(DPEngineCoreProc):
|
class BalanceDPEngineCoreProc(DPEngineCoreProc):
|
||||||
|
|
||||||
def run_busy_loop(self):
|
def run_busy_loop(self):
|
||||||
"""Core busy loop of the EngineCore for data parallel case."""
|
"""Core busy loop of the EngineCore for data parallel case."""
|
||||||
|
|
||||||
@@ -602,23 +580,23 @@ class BalanceDPEngineCoreProc(DPEngineCoreProc):
|
|||||||
self.execute_dummy_batch()
|
self.execute_dummy_batch()
|
||||||
|
|
||||||
# 3) All-reduce operation to determine global unfinished reqs.
|
# 3) All-reduce operation to determine global unfinished reqs.
|
||||||
self.engines_running = self._has_global_unfinished_reqs(
|
self.engines_running = self._has_global_unfinished_reqs(local_unfinished_reqs)
|
||||||
local_unfinished_reqs)
|
|
||||||
self.scheduler.balance_gather(self.dp_group)
|
self.scheduler.balance_gather(self.dp_group)
|
||||||
|
|
||||||
if not self.engines_running:
|
if not self.engines_running:
|
||||||
if self.dp_rank == 0 or not self.has_coordinator:
|
if self.dp_rank == 0 or not self.has_coordinator:
|
||||||
# Notify client that we are pausing the loop.
|
# Notify client that we are pausing the loop.
|
||||||
logger.debug("Wave %d finished, pausing engine loop.",
|
logger.debug("Wave %d finished, pausing engine loop.", self.current_wave)
|
||||||
self.current_wave)
|
|
||||||
# In the coordinator case, dp rank 0 sends updates to the
|
# In the coordinator case, dp rank 0 sends updates to the
|
||||||
# coordinator. Otherwise (offline spmd case), each rank
|
# coordinator. Otherwise (offline spmd case), each rank
|
||||||
# sends the update to its colocated front-end process.
|
# sends the update to its colocated front-end process.
|
||||||
client_index = -1 if self.has_coordinator else 0
|
client_index = -1 if self.has_coordinator else 0
|
||||||
self.output_queue.put_nowait((
|
self.output_queue.put_nowait(
|
||||||
client_index,
|
(
|
||||||
EngineCoreOutputs(wave_complete=self.current_wave),
|
client_index,
|
||||||
))
|
EngineCoreOutputs(wave_complete=self.current_wave),
|
||||||
|
)
|
||||||
|
)
|
||||||
# Increment wave count and reset step counter.
|
# Increment wave count and reset step counter.
|
||||||
self.current_wave += 1
|
self.current_wave += 1
|
||||||
self.step_counter = 0
|
self.step_counter = 0
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
import vllm.distributed.ec_transfer.ec_connector.example_connector
|
import vllm.distributed.ec_transfer.ec_connector.example_connector
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from vllm.distributed.ec_transfer.ec_connector.example_connector import (
|
from vllm.distributed.ec_transfer.ec_connector.example_connector import ECConnectorMetadata, ECExampleConnector
|
||||||
ECConnectorMetadata, ECExampleConnector)
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
|
||||||
class AscendECExampleConnector(ECExampleConnector):
|
class AscendECExampleConnector(ECExampleConnector):
|
||||||
|
|
||||||
def start_load_caches(self, encoder_cache, **kwargs) -> None:
|
def start_load_caches(self, encoder_cache, **kwargs) -> None:
|
||||||
metadata: ECConnectorMetadata = self._get_connector_metadata()
|
metadata: ECConnectorMetadata = self._get_connector_metadata()
|
||||||
assert isinstance(metadata, ECConnectorMetadata)
|
assert isinstance(metadata, ECConnectorMetadata)
|
||||||
assert encoder_cache is not None
|
assert encoder_cache is not None
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
logger.warning((
|
logger.warning(
|
||||||
"In connector.start_load_caches, ",
|
(
|
||||||
"but the connector metadata is None",
|
"In connector.start_load_caches, ",
|
||||||
))
|
"but the connector metadata is None",
|
||||||
|
)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
# Load the EC for each mm data
|
# Load the EC for each mm data
|
||||||
for mm_data in metadata.mm_datas:
|
for mm_data in metadata.mm_datas:
|
||||||
@@ -24,8 +24,7 @@ class AscendECExampleConnector(ECExampleConnector):
|
|||||||
filename = self._generate_filename_debug(mm_data.mm_hash)
|
filename = self._generate_filename_debug(mm_data.mm_hash)
|
||||||
ec_cache = load_file(filename)["ec_cache"].npu()
|
ec_cache = load_file(filename)["ec_cache"].npu()
|
||||||
encoder_cache[mm_data.mm_hash] = ec_cache
|
encoder_cache[mm_data.mm_hash] = ec_cache
|
||||||
logger.debug("Success load encoder cache for hash %s",
|
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
|
||||||
mm_data.mm_hash)
|
|
||||||
|
|
||||||
|
|
||||||
vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector
|
vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
|||||||
block_size=1,
|
block_size=1,
|
||||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||||
head_size=model_config.get_head_size(),
|
head_size=model_config.get_head_size(),
|
||||||
dtype=kv_cache_dtype).page_size_bytes
|
dtype=kv_cache_dtype,
|
||||||
|
).page_size_bytes
|
||||||
|
|
||||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||||
model_config.architecture,
|
model_config.architecture,
|
||||||
@@ -58,23 +59,20 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
|||||||
# block size to multiple of 16, so let's suggest a value
|
# block size to multiple of 16, so let's suggest a value
|
||||||
# that would work (note: FA is currently not compatible
|
# that would work (note: FA is currently not compatible
|
||||||
# with mamba layers, use FlashInfer instead).
|
# with mamba layers, use FlashInfer instead).
|
||||||
attn_block_size = block_alignment_bytes * cdiv(
|
attn_block_size = block_alignment_bytes * cdiv(mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
|
||||||
mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
|
|
||||||
|
|
||||||
# override attention block size if either (a) the
|
# override attention block size if either (a) the
|
||||||
# user has not set it or (b) the user has set it
|
# user has not set it or (b) the user has set it
|
||||||
# too small.
|
# too small.
|
||||||
if (cache_config.block_size is None
|
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
|
||||||
or cache_config.block_size < attn_block_size):
|
|
||||||
cache_config.block_size = attn_block_size
|
cache_config.block_size = attn_block_size
|
||||||
logger.info(
|
logger.info(
|
||||||
"Setting attention block size to %d tokens "
|
"Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.",
|
||||||
"to ensure that attention page size is >= mamba page size.",
|
attn_block_size,
|
||||||
attn_block_size)
|
)
|
||||||
|
|
||||||
# compute new attention page size
|
# compute new attention page size
|
||||||
attn_page_size = \
|
attn_page_size = cache_config.block_size * attn_page_size_1_token
|
||||||
cache_config.block_size * attn_page_size_1_token
|
|
||||||
|
|
||||||
assert attn_page_size >= mamba_page_size
|
assert attn_page_size >= mamba_page_size
|
||||||
|
|
||||||
@@ -83,15 +81,15 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# pad mamba page size to exactly match attention
|
# pad mamba page size to exactly match attention
|
||||||
if (cache_config.mamba_page_size_padded is None
|
if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
|
||||||
or cache_config.mamba_page_size_padded != attn_page_size):
|
cache_config.mamba_page_size_padded = attn_page_size
|
||||||
cache_config.mamba_page_size_padded = (attn_page_size)
|
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
|
||||||
mamba_padding_pct = 100 * (attn_page_size -
|
|
||||||
mamba_page_size) / mamba_page_size
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Padding mamba page size by %.2f%% to ensure "
|
"Padding mamba page size by %.2f%% to ensure "
|
||||||
"that mamba page size and attention page size are "
|
"that mamba page size and attention page size are "
|
||||||
"exactly equal.", mamba_padding_pct)
|
"exactly equal.",
|
||||||
|
mamba_padding_pct,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
|
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
|
||||||
|
|||||||
@@ -7,19 +7,20 @@ from multiprocessing.synchronize import Lock as LockType
|
|||||||
import vllm.v1.executor.multiproc_executor
|
import vllm.v1.executor.multiproc_executor
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
||||||
MessageQueue)
|
from vllm.utils.network_utils import get_distributed_init_method, get_loopback_ip, get_open_port
|
||||||
from vllm.utils.network_utils import (get_distributed_init_method,
|
|
||||||
get_loopback_ip, get_open_port)
|
|
||||||
from vllm.utils.system_utils import get_mp_context
|
from vllm.utils.system_utils import get_mp_context
|
||||||
from vllm.v1.executor.abstract import FailureCallback
|
from vllm.v1.executor.abstract import FailureCallback
|
||||||
from vllm.v1.executor.multiproc_executor import (
|
from vllm.v1.executor.multiproc_executor import (
|
||||||
FutureWrapper, MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc,
|
FutureWrapper,
|
||||||
set_multiprocessing_worker_envs)
|
MultiprocExecutor,
|
||||||
|
UnreadyWorkerProcHandle,
|
||||||
|
WorkerProc,
|
||||||
|
set_multiprocessing_worker_envs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AscendMultiprocExecutor(MultiprocExecutor):
|
class AscendMultiprocExecutor(MultiprocExecutor):
|
||||||
|
|
||||||
def _init_executor(self) -> None:
|
def _init_executor(self) -> None:
|
||||||
# Call self.shutdown at exit to clean up
|
# Call self.shutdown at exit to clean up
|
||||||
# and ensure workers will be terminated.
|
# and ensure workers will be terminated.
|
||||||
@@ -32,7 +33,8 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
||||||
f"global world_size ({self.parallel_config.world_size}) must be "
|
f"global world_size ({self.parallel_config.world_size}) must be "
|
||||||
f"divisible by nnodes_within_dp "
|
f"divisible by nnodes_within_dp "
|
||||||
f"({self.parallel_config.nnodes_within_dp}). ")
|
f"({self.parallel_config.nnodes_within_dp}). "
|
||||||
|
)
|
||||||
self.local_world_size = self.parallel_config.local_world_size
|
self.local_world_size = self.parallel_config.local_world_size
|
||||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||||
@@ -41,7 +43,8 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
f"world_size ({self.world_size}) must be equal to the "
|
f"world_size ({self.world_size}) must be equal to the "
|
||||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||||
f"_parallel_size ({pp_parallel_size}) x prefill_context"
|
f"_parallel_size ({pp_parallel_size}) x prefill_context"
|
||||||
f"_parallel_size ({pcp_parallel_size}). ")
|
f"_parallel_size ({pcp_parallel_size}). "
|
||||||
|
)
|
||||||
|
|
||||||
# Set multiprocessing envs
|
# Set multiprocessing envs
|
||||||
set_multiprocessing_worker_envs()
|
set_multiprocessing_worker_envs()
|
||||||
@@ -49,8 +52,7 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
# Multiprocessing-based executor does not support multi-node setting.
|
# Multiprocessing-based executor does not support multi-node setting.
|
||||||
# Since it only works for single node, we can use the loopback address
|
# Since it only works for single node, we can use the loopback address
|
||||||
# get_loopback_ip() for communication.
|
# get_loopback_ip() for communication.
|
||||||
distributed_init_method = get_distributed_init_method(
|
distributed_init_method = get_distributed_init_method(get_loopback_ip(), get_open_port())
|
||||||
get_loopback_ip(), get_open_port())
|
|
||||||
self.rpc_broadcast_mq: MessageQueue | None = None
|
self.rpc_broadcast_mq: MessageQueue | None = None
|
||||||
scheduler_output_handle: Handle | None = None
|
scheduler_output_handle: Handle | None = None
|
||||||
# Initialize worker and set up message queues for SchedulerOutputs
|
# Initialize worker and set up message queues for SchedulerOutputs
|
||||||
@@ -72,8 +74,7 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
unready_workers: list[UnreadyWorkerProcHandle] = []
|
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||||
success = False
|
success = False
|
||||||
try:
|
try:
|
||||||
global_start_rank = (self.local_world_size *
|
global_start_rank = self.local_world_size * self.parallel_config.node_rank_within_dp
|
||||||
self.parallel_config.node_rank_within_dp)
|
|
||||||
for local_rank in range(self.local_world_size):
|
for local_rank in range(self.local_world_size):
|
||||||
global_rank = global_start_rank + local_rank
|
global_rank = global_start_rank + local_rank
|
||||||
unready_workers.append(
|
unready_workers.append(
|
||||||
@@ -84,7 +85,8 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
input_shm_handle=scheduler_output_handle,
|
input_shm_handle=scheduler_output_handle,
|
||||||
shared_worker_lock=shared_worker_lock,
|
shared_worker_lock=shared_worker_lock,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Workers must be created before wait_for_ready to avoid
|
# Workers must be created before wait_for_ready to avoid
|
||||||
# deadlock, since worker.init_device() does a device sync.
|
# deadlock, since worker.init_device() does a device sync.
|
||||||
@@ -101,13 +103,11 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
if self.parallel_config.node_rank_within_dp == 0:
|
if self.parallel_config.node_rank_within_dp == 0:
|
||||||
for rank in range(self.world_size):
|
for rank in range(self.world_size):
|
||||||
if rank < self.local_world_size:
|
if rank < self.local_world_size:
|
||||||
local_message_queue = self.workers[
|
local_message_queue = self.workers[rank].worker_response_mq
|
||||||
rank].worker_response_mq
|
|
||||||
assert local_message_queue is not None
|
assert local_message_queue is not None
|
||||||
self.response_mqs.append(local_message_queue)
|
self.response_mqs.append(local_message_queue)
|
||||||
else:
|
else:
|
||||||
remote_message_queue = self.workers[
|
remote_message_queue = self.workers[0].peer_worker_response_mqs[rank]
|
||||||
0].peer_worker_response_mqs[rank]
|
|
||||||
assert remote_message_queue is not None
|
assert remote_message_queue is not None
|
||||||
self.response_mqs.append(remote_message_queue)
|
self.response_mqs.append(remote_message_queue)
|
||||||
|
|
||||||
@@ -128,8 +128,7 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
for uw in unready_workers:
|
for uw in unready_workers:
|
||||||
if uw.death_writer is not None:
|
if uw.death_writer is not None:
|
||||||
uw.death_writer.close()
|
uw.death_writer.close()
|
||||||
self._ensure_worker_termination(
|
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||||
[uw.proc for uw in unready_workers])
|
|
||||||
|
|
||||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||||
|
|
||||||
@@ -137,7 +136,6 @@ class AscendMultiprocExecutor(MultiprocExecutor):
|
|||||||
|
|
||||||
|
|
||||||
class AscendWorkerProc(WorkerProc):
|
class AscendWorkerProc(WorkerProc):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_worker_process(
|
def make_worker_process(
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
|
|||||||
@@ -3,11 +3,10 @@ import sys
|
|||||||
import vllm.distributed.utils
|
import vllm.distributed.utils
|
||||||
from vllm.platforms import CpuArchEnum, Platform
|
from vllm.platforms import CpuArchEnum, Platform
|
||||||
|
|
||||||
is_arm = (Platform.get_cpu_architecture() == CpuArchEnum.ARM)
|
is_arm = Platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||||
|
|
||||||
USE_SCHED_YIELD = (
|
USE_SCHED_YIELD = (
|
||||||
((sys.version_info[:3] >= (3, 11, 1)) or
|
(sys.version_info[:3] >= (3, 11, 1)) or (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)
|
||||||
(sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8))
|
) and not is_arm
|
||||||
and not is_arm)
|
|
||||||
|
|
||||||
vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD
|
vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD
|
||||||
|
|||||||
Reference in New Issue
Block a user