From 4e53c1d90073940e81d6161e07d9205fa225fd78 Mon Sep 17 00:00:00 2001 From: SILONG ZENG <2609716663@qq.com> Date: Sat, 24 Jan 2026 22:08:33 +0800 Subject: [PATCH] [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #6) (#6001) ### What this PR does / why we need it? | File Path | | :--- | | ` vllm_ascend/eplb/adaptor/abstract_adaptor.py` | | ` vllm_ascend/eplb/adaptor/vllm_adaptor.py` | | ` vllm_ascend/eplb/core/eplb_device_transfer_loader.py` | | ` vllm_ascend/eplb/core/eplb_utils.py` | | ` vllm_ascend/eplb/core/eplb_worker.py` | | ` vllm_ascend/eplb/core/policy/policy_abstract.py` | | ` vllm_ascend/eplb/core/policy/policy_default_eplb.py` | | ` vllm_ascend/eplb/core/policy/policy_factory.py` | | ` vllm_ascend/eplb/core/policy/policy_flashlb.py` | | ` vllm_ascend/eplb/core/policy/policy_random.py` | | ` vllm_ascend/eplb/core/policy/policy_swift_balancer.py` | | ` vllm_ascend/eplb/eplb_updator.py` | | ` vllm_ascend/eplb/utils.py` | | ` vllm_ascend/model_loader/netloader/executor/elastic_load.py` | | ` vllm_ascend/model_loader/netloader/executor/netloader_pg.py` | | ` vllm_ascend/model_loader/netloader/interaction/elastic.py` | | ` vllm_ascend/model_loader/netloader/load.py` | | ` vllm_ascend/model_loader/netloader/netloader.py` | | ` vllm_ascend/model_loader/netloader/utils.py` | | ` vllm_ascend/patch/platform/__init__.py` | | ` vllm_ascend/patch/platform/patch_balance_schedule.py` | | ` vllm_ascend/patch/platform/patch_ec_connector.py` | | ` vllm_ascend/patch/platform/patch_mamba_config.py` | | ` vllm_ascend/patch/platform/patch_multiproc_executor.py` | | ` vllm_ascend/patch/platform/patch_sched_yield.py` | - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 --------- Signed-off-by: MrZ20 <2609716663@qq.com> --- pyproject.toml | 5 +- vllm_ascend/eplb/adaptor/abstract_adaptor.py | 10 +- vllm_ascend/eplb/adaptor/vllm_adaptor.py | 100 ++-- .../eplb/core/eplb_device_transfer_loader.py | 47 +- vllm_ascend/eplb/core/eplb_utils.py | 32 +- vllm_ascend/eplb/core/eplb_worker.py | 139 ++--- .../eplb/core/policy/policy_abstract.py | 1 - .../eplb/core/policy/policy_default_eplb.py | 178 +++--- .../eplb/core/policy/policy_factory.py | 19 +- .../eplb/core/policy/policy_flashlb.py | 127 ++-- vllm_ascend/eplb/core/policy/policy_random.py | 1 - .../eplb/core/policy/policy_swift_balancer.py | 543 +++++++++--------- vllm_ascend/eplb/eplb_updator.py | 62 +- vllm_ascend/eplb/utils.py | 23 +- .../netloader/executor/elastic_load.py | 46 +- .../netloader/executor/netloader_pg.py | 44 +- .../netloader/interaction/elastic.py | 166 +++--- vllm_ascend/model_loader/netloader/load.py | 19 +- .../model_loader/netloader/netloader.py | 174 +++--- vllm_ascend/model_loader/netloader/utils.py | 15 +- vllm_ascend/patch/platform/__init__.py | 5 +- .../patch/platform/patch_balance_schedule.py | 192 +++---- .../patch/platform/patch_ec_connector.py | 17 +- .../patch/platform/patch_mamba_config.py | 30 +- .../platform/patch_multiproc_executor.py | 40 +- .../patch/platform/patch_sched_yield.py | 7 +- 26 files changed, 894 insertions(+), 1148 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b78e5d89..3f053ec5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,10 +61,6 @@ exclude = [ "vllm_ascend/distributed/kv_transfer/utils/**", "vllm_ascend/kv_offload/**", "vllm_ascend/lora/**", - # (6) - "vllm_ascend/eplb/**", - "vllm_ascend/model_loader/netloader/**", - "vllm_ascend/patch/**", # (7) "vllm_ascend/quantization/**", "vllm_ascend/sample/*.py", @@ -92,6 +88,7 @@ exclude = [ "vllm_ascend/distributed/parallel_state.py", "vllm_ascend/distributed/utils.py", "vllm_ascend/xlite/*.py", + "vllm_ascend/patch/worker/patch_*.py", # (11) "vllm_ascend/ops/fused_moe/**", ] diff --git a/vllm_ascend/eplb/adaptor/abstract_adaptor.py b/vllm_ascend/eplb/adaptor/abstract_adaptor.py index a8c6a035..ff58e170 100644 --- a/vllm_ascend/eplb/adaptor/abstract_adaptor.py +++ b/vllm_ascend/eplb/adaptor/abstract_adaptor.py @@ -19,8 +19,7 @@ from abc import abstractmethod from typing import Any -class EplbAdaptor(): - +class EplbAdaptor: def __init__(self, **args): pass @@ -29,12 +28,9 @@ class EplbAdaptor(): raise NotImplementedError @abstractmethod - def do_update_expert_map(self, layer_id: Any, - updated_expert_map: Any) -> Any: + def do_update_expert_map(self, layer_id: Any, updated_expert_map: Any) -> Any: raise NotImplementedError @abstractmethod - def do_update_expert_weight(self, layer_id: Any, - local_expert_to_replace: Any, - buffer_tensor_id: Any) -> Any: + def do_update_expert_weight(self, layer_id: Any, local_expert_to_replace: Any, buffer_tensor_id: Any) -> Any: raise NotImplementedError diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 500c9d4b..8d718213 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -26,7 +26,6 @@ from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor class VllmEplbAdaptor(EplbAdaptor): - def __init__(self, model, **args): super().__init__(**args) self.model = model @@ -36,33 +35,37 @@ class VllmEplbAdaptor(EplbAdaptor): self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0) self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers - for i in range(self.num_dense_layers, - self.model.config.num_hidden_layers): - self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \ - self.model.model.layers[i].mlp.experts.w13_weight_list - self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \ - self.model.model.layers[i].mlp.experts.w2_weight_list - self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \ + for i in range(self.num_dense_layers, self.model.config.num_hidden_layers): + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = self.model.model.layers[ + i + ].mlp.experts.w13_weight_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = self.model.model.layers[ + i + ].mlp.experts.w2_weight_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = ( self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list - self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ + ) + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = ( self.model.model.layers[i].mlp.experts.w2_weight_scale_list - # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here + ) + # TODO: init self.expert_weight_names depending on different model types. + # Only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ - "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", "w13_weight_offset", - "w2_weight_scale_list", "w2_weight_offset" + "w13_weight_list", + "w2_weight_list", + "w13_weight_scale_fp32_list", + "w13_weight_offset", + "w2_weight_scale_list", + "w2_weight_offset", ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] - self.expert_map_per_layer_cpu = dict( - ) # copy of expert map on CPU to avoid device synchronize frequently + self.expert_map_per_layer_cpu = dict() # copy of expert map on CPU to avoid device synchronize frequently num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts - self.buffer_tensor_list: list[list[Any]] = [ - [] for _ in range(num_buffer_tensor) - ] + self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)] self.init_buffer_tensor(num_buffer_tensor) self.expert_param_per_layer = dict() @@ -70,18 +73,15 @@ class VllmEplbAdaptor(EplbAdaptor): self.log2phy_map_per_layer = dict() for layer_idx in range(self.num_moe_layers): - self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \ - self.model.get_log2phy_map(self.num_dense_layers + layer_idx) + self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = self.model.get_log2phy_map( + self.num_dense_layers + layer_idx + ) def init_buffer_tensor(self, num_buffer_tensor): for buffer_id in range(num_buffer_tensor): for name in self.expert_weight_names: - complete_name = "model.layers." + str( - self.num_dense_layers) + ".mlp.experts." + name - if name in [ - "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", "w2_weight_scale_list" - ]: + complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name + if name in ["w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w2_weight_scale_list"]: expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() else: @@ -99,19 +99,20 @@ class VllmEplbAdaptor(EplbAdaptor): per_expert_param = list() for name in self.expert_weight_names: if name in [ - "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", - "w2_weight_scale_list" + "w13_weight_list", + "w2_weight_list", + "w13_weight_scale_fp32_list", + "w2_weight_scale_list", ]: per_expert_param.append( - self.param_dict["model.layers." + str(layer_idx) + - ".mlp.experts." + - name][local_expert_id]) + self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id] + ) else: per_expert_param.append( - self.param_dict["model.layers." + str(layer_idx) + - ".mlp.experts." + - name][0].data[local_expert_id]) + self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][0].data[ + local_expert_id + ] + ) self.expert_param_per_layer[layer_idx].append(per_expert_param) def get_rank_expert_workload(self) -> torch.Tensor: @@ -123,26 +124,18 @@ class VllmEplbAdaptor(EplbAdaptor): num_local_experts = expert_maps.max() + 1 expert_maps_list = expert_maps.tolist() - record: dict[str, Any] = { - "moe_layer_count": len(expert_maps_list), - "layer_list": [] - } + record: dict[str, Any] = {"moe_layer_count": len(expert_maps_list), "layer_list": []} for layer_idx, layer_data in enumerate(expert_maps_list): layer_record: dict[str, Any] = { "layer_id": layer_idx, "device_count": len(layer_data), - "device_list": [] + "device_list": [], } for device_idx, experts in enumerate(layer_data): - placement = [ - experts.index(i) for i in range(num_local_experts) - ] - device_record = { - "device_id": device_idx, - "device_expert": placement - } + placement = [experts.index(i) for i in range(num_local_experts)] + device_record = {"device_id": device_idx, "device_expert": placement} layer_record["device_list"].append(device_record) record["layer_list"].append(layer_record) @@ -153,11 +146,10 @@ class VllmEplbAdaptor(EplbAdaptor): def do_update_expert_map(self, layer_id, updated_expert_map): self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) - def do_update_expert_weight(self, layer_id, local_expert_to_replace, - buffer_tensor_id): + def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): for expert_tensor, buffer_tensor in zip( - self.expert_param_per_layer[layer_id][local_expert_to_replace], - self.buffer_tensor_list[buffer_tensor_id]): + self.expert_param_per_layer[layer_id][local_expert_to_replace], self.buffer_tensor_list[buffer_tensor_id] + ): expert_tensor.copy_(buffer_tensor) logger.debug(f"Expert tensor shape is :{expert_tensor.shape}") @@ -168,10 +160,8 @@ class VllmEplbAdaptor(EplbAdaptor): def get_global_expert_map(self): all_layer_global_expert_map = [] for layer_id in range(self.num_moe_layers): - map_cpu = self.model.model.layers[ - self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu() + map_cpu = self.model.model.layers[self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu() all_layer_global_expert_map.append(map_cpu) - self.expert_map_per_layer_cpu[self.num_dense_layers + - layer_id] = map_cpu[self.rank_id] + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_id] = map_cpu[self.rank_id] return torch.stack(all_layer_global_expert_map) diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py index ba24c88c..345a0ee1 100644 --- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -27,7 +27,6 @@ class ExpertWeightUpdateState(Enum): class D2DExpertWeightLoader: - def __init__(self): self.comm_op_list = None self.updated_expert_map = None @@ -40,14 +39,10 @@ class D2DExpertWeightLoader: def set_adator(self, eplb_adaptor): self.eplb_adaptor = eplb_adaptor - def generate_expert_d2d_transfer_task(self, expert_send_info, - expert_recv_info, updated_expert_map, - layer_id): + def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, updated_expert_map, layer_id): # When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task if self.state != ExpertWeightUpdateState.WAITING: - logger.warning_once( - "current d2d weight update tasks are on-going, cannot accept new weight update task" - ) + logger.warning_once("current d2d weight update tasks are on-going, cannot accept new weight update task") return self.updated_expert_map = updated_expert_map @@ -56,25 +51,16 @@ class D2DExpertWeightLoader: self.comm_op_list = [] for send_info in expert_send_info: dst_rank, global_expert_id_to_send = send_info - local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[ - layer_id][global_expert_id_to_send].item() - for src_tensor in self.eplb_adaptor.expert_param_per_layer[ - layer_id][local_expert_id]: - self.comm_op_list.append( - dist.P2POp(dist.isend, src_tensor, dst_rank)) + local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item() + for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]: + self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank)) - buffer_tensor_id = 0 - for recv_info in expert_recv_info: + for buffer_tensor_id, recv_info in enumerate(expert_recv_info): recv_rank, global_expert_id_to_recv = recv_info - for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[ - buffer_tensor_id]: - self.comm_op_list.append( - dist.P2POp(dist.irecv, buffer_tensor, recv_rank)) - local_expert_to_replace = self.updated_expert_map[ - global_expert_id_to_recv].item() - self.recv_expert_list.append( - (local_expert_to_replace, buffer_tensor_id)) - buffer_tensor_id += 1 + for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]: + self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank)) + local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item() + self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id)) self.state = ExpertWeightUpdateState.READY @@ -106,23 +92,18 @@ class D2DExpertWeightLoader: self.comm_op_list = None # update expert_map - self.eplb_adaptor.do_update_expert_map(self.layer_id, - self.updated_expert_map) + self.eplb_adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map) # update log2phy_map - self.eplb_adaptor.do_update_log2phy_map(self.layer_id, - self.updated_log2phy_map) + self.eplb_adaptor.do_update_log2phy_map(self.layer_id, self.updated_log2phy_map) # update expert weight buffer_tensor_id = 0 for recv_expert_info in self.recv_expert_list: local_expert_to_replace, buffer_tensor_id = recv_expert_info - self.eplb_adaptor.do_update_expert_weight(self.layer_id, - local_expert_to_replace, - buffer_tensor_id) + self.eplb_adaptor.do_update_expert_weight(self.layer_id, local_expert_to_replace, buffer_tensor_id) - logger.debug( - f"[EPLB] finished update expert weight for layer: {self.layer_id}") + logger.debug(f"[EPLB] finished update expert weight for layer: {self.layer_id}") self.recv_expert_list = [] self.updated_expert_map = None diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py index 88032d09..db5e31b5 100644 --- a/vllm_ascend/eplb/core/eplb_utils.py +++ b/vllm_ascend/eplb/core/eplb_utils.py @@ -25,7 +25,7 @@ from vllm.logger import logger def expert_file_to_tensor(expert_map_path, layer_id): - with open(expert_map_path, "r") as f: + with open(expert_map_path) as f: data = json.load(f) physical_count = 0 device_data = [] @@ -61,38 +61,32 @@ def init_eplb_config(eplb_config, layer_id, moe_config): eplb_enable = eplb_config.dynamic_eplb n_redundant = eplb_config.num_redundant_experts if eplb_enable else 0 if expert_map_path: - if not (os.path.exists(expert_map_path) - and os.access(expert_map_path, os.R_OK)): + if not (os.path.exists(expert_map_path) and os.access(expert_map_path, os.R_OK)): raise ValueError("Invalid EPLB path") eplb_enable = True - global_placement, physical_count = expert_file_to_tensor( - expert_map_path, layer_id) + global_placement, physical_count = expert_file_to_tensor(expert_map_path, layer_id) if physical_count is not None: n_redundant = physical_count - n_experts if not moe_config.supports_eplb: - raise ValueError( - "Eplb supports only w8a8_dynamic quantization.") + raise ValueError("Eplb supports only w8a8_dynamic quantization.") else: eplb_enable = False if global_placement is None: - global_placement = generate_global_placement(n_experts, ep_size, - n_redundant) + global_placement = generate_global_placement(n_experts, ep_size, n_redundant) if ep_size == 1: assert not eplb_enable, "EPLB must used in expert parallelism." return None, None, None, n_redundant global_expert_map = [] for rankid in range(ep_size): - expert_map = torch.full((n_experts, ), -1, dtype=torch.int32) + expert_map = torch.full((n_experts,), -1, dtype=torch.int32) local_placement = global_placement[rankid] - expert_map[local_placement] = torch.arange(local_placement.shape[0], - dtype=torch.int32) + expert_map[local_placement] = torch.arange(local_placement.shape[0], dtype=torch.int32) global_expert_map.append(expert_map) if rankid == moe_config.ep_rank: local_expert_map = expert_map.npu() - log2phy = generate_log2phy_map( - global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None + log2phy = generate_log2phy_map(global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant @@ -106,13 +100,15 @@ def generate_log2phy_map(global_expert_map, ep_rank): if val != -1: log2phy_map[idx].append(val + rankid * valid_count) - for key in log2phy_map.keys(): + for key in log2phy_map: num_of_duplications = len(log2phy_map[key]) log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications] log2phy_map = torch.scatter( - torch.zeros(len(log2phy_map.keys()), dtype=torch.int32), 0, - torch.tensor(list(log2phy_map.keys()), dtype=torch.int64), - torch.tensor(list(log2phy_map.values()), dtype=torch.int32)) + torch.zeros(len(log2phy_map), dtype=torch.int32), + 0, + torch.tensor(list(log2phy_map), dtype=torch.int64), + torch.tensor(list(log2phy_map.values()), dtype=torch.int32), + ) return log2phy_map diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index c9f808ec..469c84fd 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -17,23 +17,18 @@ from multiprocessing import Process, Queue from typing import Any -import networkx as nx # type: ignore -import numpy as np import torch import torch.distributed as dist from vllm.logger import logger from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map -from vllm_ascend.eplb.core.policy.policy_factory import (DynamicConfig, - PolicyFactory) +from vllm_ascend.eplb.core.policy.policy_factory import DynamicConfig, PolicyFactory class EplbWorker: - def __init__(self, shared_dict, policy_type, enable_d2d: bool = True): self.policy_type = policy_type - self.policy = PolicyFactory.generate_policy(policy_type, - DynamicConfig()) + self.policy = PolicyFactory.generate_policy(policy_type, DynamicConfig()) self.shared_dict = shared_dict self.old_expert_maps = None self.enable_d2d = enable_d2d @@ -62,10 +57,8 @@ class EplbWorker: return # Get the updated expert table based on the workload information - old_placement = self.global2local(self.old_expert_maps, - self.num_local_experts) - _, _, new_placement = self.calculate_rebalance_experts( - load_info, old_placement) + old_placement = self.global2local(self.old_expert_maps, self.num_local_experts) + _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement) if not torch.is_tensor(new_placement): new_placement = torch.tensor(new_placement) @@ -73,8 +66,7 @@ class EplbWorker: new_expert_maps = self.local2global(new_placement) self.update_expert_map(new_expert_maps) - update_info = self.compose_expert_update_info_greedy( - new_expert_maps, self.old_expert_maps) + update_info = self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps) self.old_expert_maps = new_expert_maps logger.info("EPLB Process compute complete") @@ -88,11 +80,8 @@ class EplbWorker: for layer_id in range(num_layers): # check if any logical expert is not placed on any rank - if torch.unique(new_placement[layer_id]).numel() < torch.unique( - old_placement[layer_id]).numel(): - logger.error( - f"There exists expert not placed on any rank in layer {layer_id}" - ) + if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel(): + logger.error(f"There exists expert not placed on any rank in layer {layer_id}") new_placement[layer_id] = old_placement[layer_id] continue @@ -101,28 +90,26 @@ class EplbWorker: old_placement_check = old_placement[layer_id][rank_id] # check if same logical experts are placed on the same NPU - if new_placement_check.numel() != torch.unique( - new_placement_check).numel(): + if new_placement_check.numel() != torch.unique(new_placement_check).numel(): logger.error( - f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid" + "Replicated experts are placed on the same NPU; expert placement on " + f"layer {layer_id}, rank {rank_id} is invalid" ) new_placement[layer_id] = old_placement[layer_id] break # check if there is any experts movement inside one NPU - expert_not_move = torch.isin(new_placement_check, - old_placement_check) - if not torch.equal(new_placement_check[expert_not_move], - old_placement_check[expert_not_move]): + expert_not_move = torch.isin(new_placement_check, old_placement_check) + if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]): logger.error( - f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid" + "There exists expert movement inside NPU; expert placement on " + f"layer {layer_id}, rank {rank_id} is invalid" ) new_placement[layer_id] = old_placement[layer_id] break # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases - def compose_expert_update_info_greedy(self, updated_expert_maps, - current_expert_maps): + def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps): num_layers = current_expert_maps.shape[0] for layer_id in range(num_layers): updated_expert_maps_this_layer = updated_expert_maps[layer_id] @@ -132,19 +119,23 @@ class EplbWorker: expert_recv_info_this_layer: dict[Any, Any] = {} # Guard Clause: if there is no expert weight update, avoid subsequent processing - if torch.equal(updated_expert_maps_this_layer, - current_expert_maps_this_layer): - yield (expert_send_info_this_layer, - expert_recv_info_this_layer, - updated_expert_maps_this_layer, layer_id) + if torch.equal(updated_expert_maps_this_layer, current_expert_maps_this_layer): + yield ( + expert_send_info_this_layer, + expert_recv_info_this_layer, + updated_expert_maps_this_layer, + layer_id, + ) # Parse expert_ids each rank needs to receive from other ranks - dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \ - & (updated_expert_maps_this_layer != -1)) + dst_rank_indices, experts_to_recv = torch.where( + (current_expert_maps_this_layer == -1) & (updated_expert_maps_this_layer != -1) + ) # Parse expert_ids each rank needs to send to other ranks - src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \ - & (updated_expert_maps_this_layer == -1)) + src_rank_indices, experts_to_send = torch.where( + (current_expert_maps_this_layer != -1) & (updated_expert_maps_this_layer == -1) + ) for idx in range(len(dst_rank_indices)): dst_rank_id = dst_rank_indices[idx].item() @@ -152,27 +143,27 @@ class EplbWorker: if dst_rank_id not in expert_recv_info_this_layer: expert_recv_info_this_layer[dst_rank_id] = [] - if not torch.isin(torch.tensor(expert_id), - experts_to_send).any(): + if not torch.isin(torch.tensor(expert_id), experts_to_send).any(): # if expert_id are not sent out from any npu, it will be copied from one npu holding this expert - candidate_src_rank_indices = torch.where( - current_expert_maps_this_layer[:, expert_id] != -1)[0] + candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0] else: - candidate_src_rank_indices = src_rank_indices[ - experts_to_send == expert_id] + candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id] - # TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node... + # TODO: improve selection criterion of NPU sending expert_id, + # considering intra-node or inter-node... src_rank_id = candidate_src_rank_indices[0].item() if src_rank_id not in expert_send_info_this_layer: expert_send_info_this_layer[src_rank_id] = [] - expert_send_info_this_layer[src_rank_id].append( - (dst_rank_id, expert_id)) - expert_recv_info_this_layer[dst_rank_id].append( - (src_rank_id, expert_id)) + expert_send_info_this_layer[src_rank_id].append((dst_rank_id, expert_id)) + expert_recv_info_this_layer[dst_rank_id].append((src_rank_id, expert_id)) - yield (expert_send_info_this_layer, expert_recv_info_this_layer, - updated_expert_maps_this_layer, layer_id) + yield ( + expert_send_info_this_layer, + expert_recv_info_this_layer, + updated_expert_maps_this_layer, + layer_id, + ) def calculate_rebalance_experts(self, load_info, old_placement): """ @@ -181,8 +172,7 @@ class EplbWorker: if self.old_expert_maps is None: return False, None, None - changed, priority, new_map = self.policy.rebalance_experts( - old_placement, load_info) + changed, priority, new_map = self.policy.rebalance_experts(old_placement, load_info) return changed, priority, new_map def get_init_expert_maps(self): @@ -199,19 +189,13 @@ class EplbWorker: return self.shared_dict.get("moe_load", None) def update_expert_map(self, expert_maps): - self.shared_dict["expert_maps"] = expert_maps - def global2local(self, placement: torch.Tensor, - E_local: int) -> tuple[torch.Tensor, torch.Tensor]: - + def global2local(self, placement: torch.Tensor, E_local: int) -> tuple[torch.Tensor, torch.Tensor]: L, G, _ = placement.shape device = placement.device - pt_local = torch.full((L, G, E_local), - fill_value=-1, - dtype=torch.long, - device=device) + pt_local = torch.full((L, G, E_local), fill_value=-1, dtype=torch.long, device=device) valid = placement >= 0 l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) @@ -223,7 +207,6 @@ class EplbWorker: return pt_local def local2global(self, placement_local: torch.Tensor) -> torch.Tensor: - L, G, E_local = placement_local.shape device = placement_local.device @@ -233,10 +216,7 @@ class EplbWorker: if E_global == 0: return torch.empty((L, G, 0), dtype=torch.long, device=device) - placement_global = torch.full((L, G, E_global), - fill_value=-1, - dtype=torch.long, - device=device) + placement_global = torch.full((L, G, E_global), fill_value=-1, dtype=torch.long, device=device) valid = placement_local >= 0 l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) @@ -257,11 +237,8 @@ class EplbWorker: layer_ids = [] for send_info, recv_info, new_expert_map, layer_id in update_info_generator: - - send_info_this_rank = send_info[ - self.rank_id] if self.rank_id in send_info else [] - recv_info_this_rank = recv_info[ - self.rank_id] if self.rank_id in recv_info else [] + send_info_this_rank = send_info.get(self.rank_id, []) + recv_info_this_rank = recv_info.get(self.rank_id, []) send_all.append(send_info_this_rank) recv_all.append(recv_info_this_rank) @@ -276,11 +253,7 @@ class EplbWorker: class EplbProcess: - - def __init__(self, - shared_dict, - policy_type: int = 0, - enable_d2d: bool = True): + def __init__(self, shared_dict, policy_type: int = 0, enable_d2d: bool = True): """ Args: shared_dict: Cross-process shared dict returned by Manager().dict() @@ -294,12 +267,12 @@ class EplbProcess: self.block_update_q: Queue[Any] = Queue(maxsize=1) # Create EplbWorker instance - self.worker = EplbWorker(self.shared_dict, self.policy_type, - self.enable_d2d) + self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d) def worker_process(self, planner_q, block_update_q): """ - Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. + Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, + call do_update, then notify main process update is complete. """ while True: try: @@ -314,17 +287,17 @@ class EplbProcess: break except Exception as e: - logger.warning(f"[EPLB subprocess Exiting due to error: {e}", - exc_info=True) + logger.warning( + f"[EPLB subprocess exiting due to error: {e}]", + exc_info=True, + ) break def _launch_process(self): """ Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). """ - proc = Process(target=self.worker_process, - args=(self.planner_q, self.block_update_q), - daemon=True) + proc = Process(target=self.worker_process, args=(self.planner_q, self.block_update_q), daemon=True) proc.start() return proc diff --git a/vllm_ascend/eplb/core/policy/policy_abstract.py b/vllm_ascend/eplb/core/policy/policy_abstract.py index 8ef58e29..ce2a764c 100644 --- a/vllm_ascend/eplb/core/policy/policy_abstract.py +++ b/vllm_ascend/eplb/core/policy/policy_abstract.py @@ -12,7 +12,6 @@ class DynamicConfig: class EplbPolicy: - def __init__(self, config: DynamicConfig): self.config = config diff --git a/vllm_ascend/eplb/core/policy/policy_default_eplb.py b/vllm_ascend/eplb/core/policy/policy_default_eplb.py index a43e1cb0..5348f301 100644 --- a/vllm_ascend/eplb/core/policy/policy_default_eplb.py +++ b/vllm_ascend/eplb/core/policy/policy_default_eplb.py @@ -25,13 +25,11 @@ class DynamicTable: class DefaultEplb(EplbPolicy): - def __init__(self, config: DynamicConfig): super().__init__(config) @staticmethod - def add_redundant(current_expert_table, expert_workload, - num_original_expert): + def add_redundant(current_expert_table, expert_workload, num_original_expert): layer_num, npu_num, experts_per_npu = expert_workload.shape workload_new = np.zeros((layer_num, num_original_expert)) for layer_idx in range(layer_num): @@ -40,31 +38,24 @@ class DefaultEplb(EplbPolicy): workload_layer = expert_workload[layer_idx].copy() for npu_idx in range(npu_num): for expert_idx in range(experts_per_npu): - workload_dict[placement_layer[npu_idx][ - expert_idx]] += workload_layer[npu_idx][expert_idx] + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] for expert_idx in range(num_original_expert): workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] return workload_new @staticmethod # Split hot (high-load) experts into redundant experts - def original_compute_balanced_pack_redundancy(origin_weights, card_num, - num_redundancy_expert): + def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): # Step 1: Sort the items by weight in descending order (we are sorting by weight now) # Sort based on the second element (the second value of each tuple) route_expert_num = len(origin_weights) - route_expert_redundancy: list[list[int]] = [ - [] for _ in range(route_expert_num) - ] + route_expert_redundancy: list[list[int]] = [[] for _ in range(route_expert_num)] for i in range(num_redundancy_expert): - sorted_indices = np.argsort([t[1] for t in origin_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1] weights = [origin_weights[idx] for idx in sorted_indices] - tmp_raw_weight = weights[0][1] * ( - len(route_expert_redundancy[weights[0][0]]) + 1) + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) route_expert_redundancy[weights[0][0]].append(route_expert_num + i) - avg_weight = tmp_raw_weight / ( - len(route_expert_redundancy[weights[0][0]]) + 1) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) weights[0] = (weights[0][0], avg_weight) origin_weights = weights @@ -93,8 +84,7 @@ class DefaultEplb(EplbPolicy): box_counts[index] += 1 index += 1 - sorted_indices = np.argsort([t[1] for t in origin_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1] origin_weights = [origin_weights[idx] for idx in sorted_indices] # Step 4: Distribute items into boxes based on weight for item_id, weight in origin_weights: @@ -104,11 +94,8 @@ class DefaultEplb(EplbPolicy): if item_id in boxes[i]: continue # Only choose boxes that still have space (box_counts[i] < items_per_box) - if box_counts[i] < items_per_box or (box_counts[i] - == items_per_box - and remaining_items > 0): - if min_box_index == -1 or box_weights[i] < box_weights[ - min_box_index]: + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: min_box_index = i # Place the item (id) into the selected box @@ -118,40 +105,35 @@ class DefaultEplb(EplbPolicy): box_counts[min_box_index] += 1 # If there's an imbalance in the remaining items, reduce the "remaining_items" counter - if box_counts[min_box_index] == (items_per_box + - 1) and remaining_items > 0: + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: remaining_items -= 1 # Step 5: Output each box's contents and total weight result = [] for i in range(card_num): - result.append({ - "box_index": i + 1, - "items": boxes[i], # List of item IDs in the box - "weight": boxes_weights[i], - "total_weight": box_weights[i], # Total weight in this box - "item_count": box_counts[i] # Number of items in the box - }) + result.append( + { + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i], # Number of items in the box + } + ) return result, boxes # Split hot (high-load) experts into redundant experts @staticmethod - def compute_balanced_pack_redundancy(origin_weights, card_num, - num_redundancy_expert): + def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): route_expert_num = len(origin_weights) - route_expert_redundancy: list[list[int]] = [ - [] for _ in range(route_expert_num) - ] + route_expert_redundancy: list[list[int]] = [[] for _ in range(route_expert_num)] for i in range(num_redundancy_expert): - sorted_indices = np.argsort([t[1] for t in origin_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1] weights = [origin_weights[idx] for idx in sorted_indices] - tmp_raw_weight = weights[0][1] * ( - len(route_expert_redundancy[weights[0][0]]) + 1) + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) route_expert_redundancy[weights[0][0]].append(route_expert_num + i) - avg_weight = tmp_raw_weight / ( - len(route_expert_redundancy[weights[0][0]]) + 1) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) weights[0] = (weights[0][0], avg_weight) origin_weights = weights @@ -166,7 +148,7 @@ class DefaultEplb(EplbPolicy): box_weights = [0] * card_num box_counts = [0] * card_num - all_weights = np.zeros((expert_num, ), dtype='object') + all_weights = np.zeros((expert_num,), dtype="object") all_weights[:route_expert_num] = origin_weights index = route_expert_num @@ -178,17 +160,13 @@ class DefaultEplb(EplbPolicy): all_weights[index] = (item, weight) index += 1 - sorted_indices = np.argsort([t[1] for t in all_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([t[1] for t in all_weights], kind="stable")[::-1] all_weights = [all_weights[idx] for idx in sorted_indices] for item_id, weight in all_weights: min_box_index = -1 for i in range(card_num): - if box_counts[i] < items_per_box or (box_counts[i] - == items_per_box - and remaining_items > 0): - if min_box_index == -1 or box_weights[i] < box_weights[ - min_box_index]: + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: if item_id not in boxes[i]: min_box_index = i @@ -197,19 +175,20 @@ class DefaultEplb(EplbPolicy): box_weights[min_box_index] += weight box_counts[min_box_index] += 1 - if box_counts[min_box_index] == (items_per_box + - 1) and remaining_items > 0: + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: remaining_items -= 1 result = [] for i in range(card_num): - result.append({ - "box_index": i + 1, - "items": boxes[i], - "weight": boxes_weights[i], - "total_weight": box_weights[i], - "item_count": box_counts[i] - }) + result.append( + { + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i], + } + ) return result, boxes @@ -232,11 +211,8 @@ class DefaultEplb(EplbPolicy): for item_id, weight in weights: min_box_index = -1 for i in range(card_num): - if box_counts[i] < items_per_box or (box_counts[i] - == items_per_box - and remaining_items > 0): - if min_box_index == -1 or box_weights[i] < box_weights[ - min_box_index]: + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: min_box_index = i boxes[min_box_index].append(item_id) @@ -244,19 +220,20 @@ class DefaultEplb(EplbPolicy): box_weights[min_box_index] += weight box_counts[min_box_index] += 1 - if box_counts[min_box_index] == (items_per_box + - 1) and remaining_items > 0: + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: remaining_items -= 1 result = [] for i in range(card_num): - result.append({ - "box_index": i + 1, - "items": boxes[i], - "weight": boxes_weights[i], - "total_weight": box_weights[i], - "item_count": box_counts[i] - }) + result.append( + { + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i], + } + ) return result, boxes @@ -274,16 +251,11 @@ class DefaultEplb(EplbPolicy): return max_heat_per_layer @staticmethod - def constraint_expert_local_exchange(current_expert_table, - global_deployment): + def constraint_expert_local_exchange(current_expert_table, global_deployment): for layer_id in range(len(global_deployment)): for card_id in range(len(global_deployment[layer_id])): - current_list = [ - int(x) for x in current_expert_table[layer_id][card_id] - ] - new_list = [ - int(x) for x in global_deployment[layer_id][card_id] - ] + current_list = [int(x) for x in current_expert_table[layer_id][card_id]] + new_list = [int(x) for x in global_deployment[layer_id][card_id]] num = len(new_list) new_index = [-1] * num @@ -293,8 +265,7 @@ class DefaultEplb(EplbPolicy): for i in range(num): flag = True for j in range(num): - if new_list[i] == current_list[j] and new_index[ - j] == -1: + if new_list[i] == current_list[j] and new_index[j] == -1: new_index[j] = 0 new_result[j] = current_list[j] flag = False @@ -313,7 +284,6 @@ class DefaultEplb(EplbPolicy): return global_deployment def rebalance_experts(self, current_expert_table, expert_workload): - info = DynamicTable() info.workload_table = np.array(expert_workload) info.placement_table = np.array(current_expert_table) @@ -324,17 +294,15 @@ class DefaultEplb(EplbPolicy): expert_ids, counts = np.unique(row, return_counts=True) num_redundancy_expert = self.get_redundant_num(num_npus, counts) num_original_expert = len(expert_ids) - layer_workloads = self.add_redundant(info.placement_table, - info.workload_table, - num_original_expert) - max_heat_per_layer_before = self.calculate_max_heat_per_layer( - info.workload_table, layer_num) + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) npu_heat_all_origin = sum(max_heat_per_layer_before) # Perform load balancing and deploy redundant experts layer_num = layer_workloads.shape[0] expert_num = layer_workloads.shape[1] - # Validate that the number of experts, number of cards, and number of redundant experts do not exceed the number of cards + # Validate that the number of experts, number of cards, and number of redundant experts + # do not exceed the number of cards. if num_original_expert != expert_num: raise ValueError( f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}" @@ -345,38 +313,35 @@ class DefaultEplb(EplbPolicy): if num_npus < num_redundancy_expert: raise ValueError( - f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}" + "the number of NPUs " + f"{num_npus} must be greater than or equal to the number of redundant experts " + f"{num_redundancy_expert}" ) # Number of experts deployed on each card includes one redundant expert - global_deployment: list[list[list[int]]] = [[[] - for _ in range(num_npus)] - for _ in range(layer_num)] + global_deployment: list[list[list[int]]] = [[[] for _ in range(num_npus)] for _ in range(layer_num)] # Iterate to obtain the placement strategy for each layer, taking computational balance into account max_heat_per_layer_after = np.zeros([layer_num]) for layer in range(layer_num): # Get the expert IDs and their corresponding workloads for the current layer; # workloads need to be normalized, and one redundant expert is added per card - weights = np.zeros((expert_num, ), dtype='object') - for expert_id, workload_weight in enumerate( - layer_workloads[layer]): + weights = np.zeros((expert_num,), dtype="object") + for expert_id, workload_weight in enumerate(layer_workloads[layer]): weights[expert_id] = (expert_id, workload_weight) # Obtain the globally balanced placement strategy for each layer result, layer_deployment = self.original_compute_balanced_pack_redundancy( - weights, num_npus, num_redundancy_expert) + weights, num_npus, num_redundancy_expert + ) global_deployment[layer] = layer_deployment - max_heat_per_layer_after[layer] = max( - result, key=lambda x: x['total_weight'])['total_weight'] + max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_weight"])["total_weight"] - new_global_deployment = self.constraint_expert_local_exchange( - current_expert_table, global_deployment) + new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment) # Obtain the priority of each layer layer_changed_ratio = [] for layer_idx in range(layer_num): - layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / - max_heat_per_layer_before[layer_idx]) + layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx]) per_layer_priority = np.argsort(layer_changed_ratio) npu_heat_all_after = sum(max_heat_per_layer_after) @@ -385,5 +350,4 @@ class DefaultEplb(EplbPolicy): if npu_heat_all_after < 0.95 * npu_heat_all_origin: change = 1 - return change, per_layer_priority, np.array( - new_global_deployment).tolist() + return change, per_layer_priority, np.array(new_global_deployment).tolist() diff --git a/vllm_ascend/eplb/core/policy/policy_factory.py b/vllm_ascend/eplb/core/policy/policy_factory.py index d21b2872..9b9f8e22 100644 --- a/vllm_ascend/eplb/core/policy/policy_factory.py +++ b/vllm_ascend/eplb/core/policy/policy_factory.py @@ -2,29 +2,26 @@ # Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory. from .policy_abstract import DynamicConfig, EplbPolicy from .policy_default_eplb import DefaultEplb -from .policy_swift_balancer import SwiftBalanceEplb from .policy_flashlb import FlashLB, warm_up from .policy_random import RandomLoadBalance +from .policy_swift_balancer import SwiftBalanceEplb class PolicyFactory: - @staticmethod def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy: policy = { # Constraint applying Dynamic EPLB policy V2: # If there exists redundant expert: # only one redundant expert can be placed in one NPU and its physical expert index must be 0 - # Applying greedy d2d expert weight update composing - 0: - RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3 - 1: - DefaultEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load - 2: - SwiftBalanceEplb, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle - 3: - FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment + 0: RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3 + 1: DefaultEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load + # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle + 2: SwiftBalanceEplb, + # FlashLB EPLB policy: expert replacement based on Joint Optimization, + # Multi-Shot Enhancement and Incremental Adjustment + 3: FlashLB, } policy_class = policy.get(policy_type, RandomLoadBalance) policy_instance = policy_class(config) diff --git a/vllm_ascend/eplb/core/policy/policy_flashlb.py b/vllm_ascend/eplb/core/policy/policy_flashlb.py index 7a13bee2..3ceefa73 100644 --- a/vllm_ascend/eplb/core/policy/policy_flashlb.py +++ b/vllm_ascend/eplb/core/policy/policy_flashlb.py @@ -3,7 +3,6 @@ import logging from collections import deque -from typing import Dict import numpy as np import torch @@ -45,8 +44,7 @@ def compute_piece_counts(X, P, stage_weights): secv = unit[i, idx2] alt = X[i, idx1] / (pieces[idx1] + 1) delta = origin - (alt if alt > secv else secv) - deltas[idx1] += delta * stage_weights[i] if np.any( - delta) != 0 else stage_weights[i] + deltas[idx1] += delta * stage_weights[i] if np.any(delta) != 0 else stage_weights[i] max_idx = np.argmax(deltas) pieces[max_idx] += 1 @@ -157,9 +155,7 @@ def jsq_placement(X, pieces, M, stage_weights): # Get elements already in current column current_rank_elements = set(deployment[rank, :]) # Filter elements from range(N) not in current column - available = [ - x for x in range(N) if x not in current_rank_elements - ] + available = [x for x in range(N) if x not in current_rank_elements] # Randomly select an available element to fill if len(available) > 0: rand_idx = np.random.randint(0, len(available)) @@ -187,8 +183,7 @@ def slice_values(X, pieces): @njit -def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, - simulated_deployment, stage_weights): +def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, simulated_deployment, stage_weights): n_stage, N = X.shape num_group = P // M @@ -207,12 +202,11 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, flat_deployment = simulated_deployment.reshape(-1) simulated_load = np.zeros(M, dtype=np.float32) for i in range(flat_deployment.shape[0]): - simulated_load[i // (flat_deployment.shape[0] // - M)] += unit_load[flat_deployment[i]] + simulated_load[i // (flat_deployment.shape[0] // M)] += unit_load[flat_deployment[i]] slice_vals = slice_values(X_all, simulated_pieces) sorted_slices = np.sort(slice_vals)[::-1] - simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M + simulated_slopes = (sorted_slices[: -M + 1] - sorted_slices[M - 1 :]) / M cumulative_slices_used = np.zeros(N, dtype=np.int32) acc = 0 @@ -230,8 +224,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, slices_used_per_group = np.zeros(num_group, dtype=np.int32) slices_used_per_group[0] = group_boundary_indices[0] for i in range(1, num_group): - slices_used_per_group[ - i] = group_boundary_indices[i] - group_boundary_indices[i - 1] + slices_used_per_group[i] = group_boundary_indices[i] - group_boundary_indices[i - 1] slices_used_per_group = M - slices_used_per_group loads = np.zeros(M, dtype=np.float32) @@ -240,7 +233,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, current_idx = 0 for g in range(num_group): - window = X_sorted[:, current_idx:current_idx + 2 * M] + window = X_sorted[:, current_idx : current_idx + 2 * M] low = max(0, current_idx + M - N) high = min(num_remain_slice, M - 1) @@ -248,8 +241,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, mid = int((high + low) // 2) keep = M - mid current_group = window[:, :keep] - current_pieces = compute_piece_counts(current_group, M, - stage_weights) + current_pieces = compute_piece_counts(current_group, M, stage_weights) current_pieces = np.maximum(current_pieces, 1) current_slice = slice_values(current_group.sum(0), current_pieces) current_slice_sorted = np.sort(current_slice) @@ -257,8 +249,7 @@ def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, current_max: np.float32 = np.max(current_loads) current_min: np.float32 = np.min(current_loads) current_slope = (current_max - current_min) / M - next_slope: np.float32 = np.max(simulated_slopes[current_idx + - keep:]) + next_slope: np.float32 = np.max(simulated_slopes[current_idx + keep :]) if abs(current_slope) > abs(next_slope): low = mid @@ -327,9 +318,7 @@ def auto_fix_new_placement(old_placement, new_placement): old_row = old_placement[rank_id] new_row = new_placement[rank_id] - index_array = np.full((max_expert + 1, num_experts), - -1, - dtype=np.int32) + index_array = np.full((max_expert + 1, num_experts), -1, dtype=np.int32) count_array = np.zeros(max_expert + 1, dtype=np.int32) for idx in range(num_experts): @@ -387,21 +376,17 @@ def auto_fix_new_placement(old_placement, new_placement): class FlashLB(EplbPolicy): - def __init__(self, config: DynamicConfig): super().__init__(config) - self.par_history: Dict[int, float] = {} - self.hotness_window: Dict[int, deque[float]] = {} - self.max_stage_window = (config.max_stage_window if hasattr( - config, "max_stage_window") else 1) + self.par_history: dict[int, float] = {} + self.hotness_window: dict[int, deque[float]] = {} + self.max_stage_window = config.max_stage_window if hasattr(config, "max_stage_window") else 1 self.buffer_expert_layer_num = ( - config.buffer_expert_layer_num if hasattr( - config, "buffer_expert_layer_num") else 58) - self.threshold_ratio = (config.threshold_ratio if hasattr( - config, "threshold_ratio") else 0) + config.buffer_expert_layer_num if hasattr(config, "buffer_expert_layer_num") else 58 + ) + self.threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else 0 - def compute_expert_hotness(self, num_of_expert: int, - deployment: np.ndarray, rank_load: np.ndarray): + def compute_expert_hotness(self, num_of_expert: int, deployment: np.ndarray, rank_load: np.ndarray): hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) deployment_flat = deployment.ravel() rank_load_flat = rank_load.ravel() @@ -413,22 +398,14 @@ class FlashLB(EplbPolicy): if np.any(deployment < 0): raise ValueError("Deployment table contains negative values.") counts = np.bincount(deployment.reshape(-1), minlength=N) - unit_hotness = np.divide(hotness, - counts, - out=np.zeros_like(hotness, dtype=float), - where=counts != 0) + unit_hotness = np.divide(hotness, counts, out=np.zeros_like(hotness, dtype=float), where=counts != 0) stage_par = np.zeros(n_stage) for i in range(n_stage): stage_load = unit_hotness[i][deployment].sum(-1) stage_par[i] = stage_load.max() / stage_load.mean() return stage_par.mean() - def group_based_adaptive_bloating(self, - X, - P, - M, - stage_weights=None, - recorsive=False): + def group_based_adaptive_bloating(self, X, P, M, stage_weights=None, recorsive=False): n_stage, N = X.shape if stage_weights is None: stage_weights = np.ones(n_stage, dtype=np.float32) @@ -437,15 +414,10 @@ class FlashLB(EplbPolicy): ( simulated_deployment, simulated_pieces, - ) = self.group_based_adaptive_bloating(X, - P, - M, - stage_weights, - recorsive=False) + ) = self.group_based_adaptive_bloating(X, P, M, stage_weights, recorsive=False) else: simulated_pieces = compute_piece_counts(X, P, stage_weights) - simulated_deployment = jsq_placement(X, simulated_pieces, M, - stage_weights) + simulated_deployment = jsq_placement(X, simulated_pieces, M, stage_weights) pieces = group_based_adaptive_bloating_kernel( X.astype(np.float32), @@ -459,10 +431,7 @@ class FlashLB(EplbPolicy): deployment = jsq_placement(X, pieces, M, stage_weights) X_all = X.sum(0) - unit_load = np.divide(X_all, - pieces, - out=np.zeros_like(X_all, dtype=float), - where=pieces != 0) + unit_load = np.divide(X_all, pieces, out=np.zeros_like(X_all, dtype=float), where=pieces != 0) load = unit_load[deployment].sum(-1) sim_unit_load = X_all / simulated_pieces @@ -510,16 +479,13 @@ class FlashLB(EplbPolicy): def register_hotness(self, deployment, rank_load, num_layer, num_expert): for layer in range(num_layer): if layer not in self.hotness_window: - self.hotness_window[layer] = deque( - maxlen=self.max_stage_window) - hotness = self.compute_expert_hotness(num_expert, - deployment[layer], - rank_load[layer]) + self.hotness_window[layer] = deque(maxlen=self.max_stage_window) + hotness = self.compute_expert_hotness(num_expert, deployment[layer], rank_load[layer]) self.hotness_window[layer].append(hotness) def compress_by_avg_pooling_fast_nd(self, arr, m): n, d = arr.shape - idx = (np.arange(n) * m // n) + idx = np.arange(n) * m // n result = np.zeros((m, d)) counts = np.zeros((m, 1)) np.add.at(result, idx, arr) @@ -532,8 +498,7 @@ class FlashLB(EplbPolicy): expert_workload += 1 num_layer = expert_workload.shape[0] num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] - self.register_hotness(current_deployment, expert_workload, num_layer, - num_expert) + self.register_hotness(current_deployment, expert_workload, num_layer, num_expert) new_deployment = current_deployment.copy() @@ -544,21 +509,17 @@ class FlashLB(EplbPolicy): for i, layer in enumerate(layers_need_update): hotness = np.array(self.hotness_window[layer]) if hotness.shape[0] > self.max_stage_window: - hotness = self.compress_by_avg_pooling_fast_nd( - hotness, self.max_stage_window) + hotness = self.compress_by_avg_pooling_fast_nd(hotness, self.max_stage_window) ( new_deployment[layer], new_par[i], current_par[i], - ) = self.rebalance_layer(current_deployment[layer], - hotness, - layer_id=layer) + ) = self.rebalance_layer(current_deployment[layer], hotness, layer_id=layer) priority = new_par / current_par priority_idx = np.argsort(priority) - priority_idx = priority_idx[priority[priority_idx] < - 1][:self.buffer_expert_layer_num] + priority_idx = priority_idx[priority[priority_idx] < 1][: self.buffer_expert_layer_num] if np.all(expert_workload == 1): for _, layer in enumerate(layers_need_update): @@ -572,16 +533,12 @@ class FlashLB(EplbPolicy): layers_need_update = priority_idx deployment = current_deployment for layer in layers_need_update: - deployment[layer] = auto_fix_new_placement( - current_deployment[layer], new_deployment[layer]) + deployment[layer] = auto_fix_new_placement(current_deployment[layer], new_deployment[layer]) return change, layers_need_update, deployment -def generate_layered_experts(num_layers=58, - layer_shape=(32, 9), - expert_min=0, - expert_max=255): +def generate_layered_experts(num_layers=58, layer_shape=(32, 9), expert_min=0, expert_max=255): """ Generate expert deployment matrix meeting the following conditions: - Total of num_layers layers @@ -598,32 +555,25 @@ def generate_layered_experts(num_layers=58, """ # 1. Basic parameter calculation expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) - layer_total = layer_shape[0] * layer_shape[ - 1] # Total elements in a single layer: 32*9=288 + layer_total = layer_shape[0] * layer_shape[1] # Total elements in a single layer: 32*9=288 extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) assert layer_total >= expert_num, ( - f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, " - "cannot cover all experts") + f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, cannot cover all experts" + ) # 3. Generate layers one by one layers = [] for _ in range(num_layers): # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) - full_experts = torch.arange(expert_min, - expert_max + 1, - dtype=torch.int64) # shape (256,) + full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64) # shape (256,) # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255) - extra_experts = torch.randint(expert_min, - expert_max + 1, - size=(extra_slots, ), - dtype=torch.int64) # shape (32,) + extra_experts = torch.randint(expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64) # shape (32,) # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) - layer_flat = torch.cat([full_experts, extra_experts], - dim=0) # shape (288,) + layer_flat = torch.cat([full_experts, extra_experts], dim=0) # shape (288,) # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) shuffle_idx = torch.randperm(layer_flat.shape[0]) layer_shuffled = layer_flat[shuffle_idx] @@ -642,7 +592,6 @@ def warm_up(): exam_config.num_die_per_host = 16 algo = FlashLB(exam_config) # Generate target tensor - expert_tensor = generate_layered_experts(num_layers=58, - layer_shape=(32, 9)) + expert_tensor = generate_layered_experts(num_layers=58, layer_shape=(32, 9)) algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9))) diff --git a/vllm_ascend/eplb/core/policy/policy_random.py b/vllm_ascend/eplb/core/policy/policy_random.py index 558d6530..7bcdb3c0 100644 --- a/vllm_ascend/eplb/core/policy/policy_random.py +++ b/vllm_ascend/eplb/core/policy/policy_random.py @@ -9,7 +9,6 @@ random.seed(42) class RandomLoadBalance(EplbPolicy): - def __init__(self, config: DynamicConfig): super().__init__(config) diff --git a/vllm_ascend/eplb/core/policy/policy_swift_balancer.py b/vllm_ascend/eplb/core/policy/policy_swift_balancer.py index 5babf31d..9418ef86 100644 --- a/vllm_ascend/eplb/core/policy/policy_swift_balancer.py +++ b/vllm_ascend/eplb/core/policy/policy_swift_balancer.py @@ -16,7 +16,6 @@ class DynamicConfig: class EplbPolicy: - def __init__(self, config: DynamicConfig): self.config = config @@ -63,7 +62,6 @@ class DynamicTable: class SwiftBalanceEplb(EplbPolicy): - def __init__(self, config: DynamicConfig): super().__init__(config) @@ -89,8 +87,7 @@ class SwiftBalanceEplb(EplbPolicy): return a % b @staticmethod - def add_redundant(current_expert_table, expert_workload, - num_original_expert): + def add_redundant(current_expert_table, expert_workload, num_original_expert): layer_num, npu_num, experts_per_npu = expert_workload.shape workload_new = np.zeros((layer_num, num_original_expert)) for layer_idx in range(layer_num): @@ -99,8 +96,7 @@ class SwiftBalanceEplb(EplbPolicy): workload_layer = expert_workload[layer_idx].copy() for npu_idx in range(npu_num): for expert_idx in range(experts_per_npu): - workload_dict[placement_layer[npu_idx][ - expert_idx]] += workload_layer[npu_idx][expert_idx] + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] for expert_idx in range(num_original_expert): workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] return workload_new @@ -118,9 +114,7 @@ class SwiftBalanceEplb(EplbPolicy): max_heat_per_layer.append(np.max(npu_heats_now)) return max_heat_per_layer - def calculate_initial_imbalance(self, global_deployment, - new_layer_workloads): - + def calculate_initial_imbalance(self, global_deployment, new_layer_workloads): device_num = global_deployment.shape[1] layer_imbalance = [] expert_num = np.zeros_like(new_layer_workloads) @@ -136,56 +130,54 @@ class SwiftBalanceEplb(EplbPolicy): box_workload = 0 for expert_id in box: update_workload = self.safe_divide( - new_layer_workloads[layer_id][expert_id], - expert_num[layer_id][expert_id]) + new_layer_workloads[layer_id][expert_id], expert_num[layer_id][expert_id] + ) box_workload += update_workload total_workload += update_workload if cur_layer_max_workload < box_workload: cur_layer_max_workload = box_workload cur_layer_imbalance = self.safe_divide( - cur_layer_max_workload, - (self.safe_divide(total_workload, device_num))) + cur_layer_max_workload, (self.safe_divide(total_workload, device_num)) + ) layer_imbalance.append(cur_layer_imbalance) return layer_imbalance - def compute_redundant_assignments(self, base_experts, - num_redundant_experts, num_experts): - - redundant_assignments: list[list[int]] = [[] - for _ in range(num_experts)] + def compute_redundant_assignments(self, base_experts, num_redundant_experts, num_experts): + redundant_assignments: list[list[int]] = [[] for _ in range(num_experts)] current_weights = base_experts.copy() for i in range(num_redundant_experts): - sorted_indices = np.argsort([w for _, w in current_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1] sorted_weights = [current_weights[i] for i in sorted_indices] target_expert = sorted_weights[0] expert_id, original_weight = target_expert current_redundancy = len(redundant_assignments[expert_id]) - new_avg_weight = self.safe_divide( - original_weight * (current_redundancy + 1), - (current_redundancy + 2)) + new_avg_weight = self.safe_divide(original_weight * (current_redundancy + 1), (current_redundancy + 2)) redundant_assignments[expert_id].append(num_experts + i) current_weights[sorted_indices[0]] = (expert_id, new_avg_weight) - sorted_indices = np.argsort([w for _, w in current_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1] sorted_weights = [current_weights[i] for i in sorted_indices] return redundant_assignments, sorted_weights - def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos, - num_experts, num_exist_expert, - device_assignments, device_counts, - expert_from_device, - com_between_devices): - - current_weights = np.zeros((num_experts, ), dtype='object') + def repeat_compute_redundant_assignments( + self, + layer_workloads, + rendun_pos, + num_experts, + num_exist_expert, + device_assignments, + device_counts, + expert_from_device, + com_between_devices, + ): + current_weights = np.zeros((num_experts,), dtype="object") for expert_id, workload_weight in enumerate(layer_workloads): current_weights[expert_id] = (expert_id, workload_weight) @@ -195,8 +187,7 @@ class SwiftBalanceEplb(EplbPolicy): devices_with_slots.append(device_id) while devices_with_slots: - sorted_indices = np.argsort([w for _, w in current_weights], - kind='stable')[::-1] + sorted_indices = np.argsort([w for _, w in current_weights], kind="stable")[::-1] sorted_weights = [current_weights[i] for i in sorted_indices] for index, target_weight in enumerate(sorted_weights): @@ -210,17 +201,15 @@ class SwiftBalanceEplb(EplbPolicy): pos = rendun_pos[cur_device_id].pop() if len(rendun_pos[cur_device_id]) == 0: devices_with_slots = [ - device_id for device_id in devices_with_slots - if device_id != cur_device_id + device_id for device_id in devices_with_slots if device_id != cur_device_id ] device_assignments[cur_device_id][pos] = expert_id device_counts[cur_device_id] += 1 communication_box_index = expert_from_device[expert_id] - com_between_devices[cur_device_id][ - communication_box_index] = expert_id + com_between_devices[cur_device_id][communication_box_index] = expert_id new_weight = self.safe_divide( - (original_weight * num_exist_expert[expert_id]), - (num_exist_expert[expert_id] + 1)) + (original_weight * num_exist_expert[expert_id]), (num_exist_expert[expert_id] + 1) + ) sorted_weights[index] = (expert_id, new_weight) num_exist_expert[expert_id] += 1 redundancy_successful = True @@ -228,41 +217,31 @@ class SwiftBalanceEplb(EplbPolicy): if redundancy_successful: break - sorted_indices = np.argsort([id for id, _ in sorted_weights], - kind='stable') + sorted_indices = np.argsort([id for id, _ in sorted_weights], kind="stable") sorted_weights = [sorted_weights[i][1] for i in sorted_indices] return sorted_weights, device_assignments, device_counts, com_between_devices @staticmethod - def prepare_expert_list(base_experts, redundant_assignments, - num_redundant_experts): + def prepare_expert_list(base_experts, redundant_assignments, num_redundant_experts): redundant_expert_list = np.empty(num_redundant_experts, dtype=object) index = 0 num_experts = len(redundant_assignments) for expert_id in range(num_experts): for _ in redundant_assignments[expert_id]: - redundant_expert_list[index] = (expert_id, - next(w - for eid, w in base_experts - if eid == expert_id)) + redundant_expert_list[index] = (expert_id, next(w for eid, w in base_experts if eid == expert_id)) index += 1 - sorted_indices = np.argsort([w for _, w in redundant_expert_list], - kind='stable')[::-1] + sorted_indices = np.argsort([w for _, w in redundant_expert_list], kind="stable")[::-1] return [redundant_expert_list[i] for i in sorted_indices] @staticmethod - def non_redundant_expert_information(origin_deployment, updated_weights, - rendun_pos): - + def non_redundant_expert_information(origin_deployment, updated_weights, rendun_pos): device_num = len(origin_deployment) num_experts_per_device = origin_deployment.shape[1] - device_assignments = [[-1 for _ in range(num_experts_per_device)] - for _ in range(device_num)] - device_weights = [[0 for _ in range(num_experts_per_device)] - for _ in range(device_num)] + device_assignments = [[-1 for _ in range(num_experts_per_device)] for _ in range(device_num)] + device_weights = [[0 for _ in range(num_experts_per_device)] for _ in range(device_num)] device_loads = [0] * device_num device_counts = [0] * device_num @@ -272,8 +251,8 @@ class SwiftBalanceEplb(EplbPolicy): continue device_assignments[device_id][index] = expert_id cur_weight = next( - weight for expert_id_of_weight, weight in updated_weights - if expert_id_of_weight == expert_id) + weight for expert_id_of_weight, weight in updated_weights if expert_id_of_weight == expert_id + ) device_weights[device_id][index] = cur_weight device_loads[device_id] += cur_weight device_counts[device_id] += 1 @@ -292,19 +271,24 @@ class SwiftBalanceEplb(EplbPolicy): if num_all_experts[expert_id] == 0: cur_layer_workload.append(-1) else: - cur_layer_workload.append( - self.safe_divide(weight, num_all_experts[expert_id])) + cur_layer_workload.append(self.safe_divide(weight, num_all_experts[expert_id])) return cur_layer_workload, num_all_experts - def distribute_redun_experts(self, layer_workloads, device_assignments, - device_weights, device_loads, device_counts, - redundant_expert_list, expert_from_device, - num_experts, rendun_pos): - + def distribute_redun_experts( + self, + layer_workloads, + device_assignments, + device_weights, + device_loads, + device_counts, + redundant_expert_list, + expert_from_device, + num_experts, + rendun_pos, + ): num_devices = len(device_assignments) - com_between_devices: list[dict[int, - int]] = [{} for _ in range(num_devices)] + com_between_devices: list[dict[int, int]] = [{} for _ in range(num_devices)] for expert_id, weight in redundant_expert_list: candidate = -1 @@ -313,8 +297,7 @@ class SwiftBalanceEplb(EplbPolicy): continue if expert_id in device_assignments[dev_id]: continue - if candidate == -1 or device_loads[dev_id] < device_loads[ - candidate]: + if candidate == -1 or device_loads[dev_id] < device_loads[candidate]: candidate = dev_id if candidate != -1: pos = rendun_pos[candidate].pop() @@ -324,31 +307,42 @@ class SwiftBalanceEplb(EplbPolicy): device_counts[candidate] += 1 communication_box_index = expert_from_device[expert_id] - com_between_devices[candidate][ - communication_box_index] = expert_id + com_between_devices[candidate][communication_box_index] = expert_id if any(sublist for sublist in rendun_pos): - cur_layer_workload, num_exist_expert = self.recomputing_initial_weight( - layer_workloads, device_assignments) + cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(layer_workloads, device_assignments) - update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments( - cur_layer_workload, rendun_pos, num_experts, num_exist_expert, - device_assignments, device_loads, expert_from_device, - com_between_devices) + update_workload, device_assignments, device_counts, com_between_devices = ( + self.repeat_compute_redundant_assignments( + cur_layer_workload, + rendun_pos, + num_experts, + num_exist_expert, + device_assignments, + device_loads, + expert_from_device, + com_between_devices, + ) + ) device_loads = [0] * len(device_counts) for device_id, device in enumerate(device_assignments): for index, expert_id in enumerate(device): - device_weights[device_id][index] = update_workload[ - expert_id] + device_weights[device_id][index] = update_workload[expert_id] device_loads[device_id] += update_workload[expert_id] return device_assignments, device_weights, device_loads, device_counts, com_between_devices - def redundancy_again(self, layer_workloads, origin_weights, - origin_deployment, expert_from_device, num_node, - is_node_redundant, rendun_pos): - + def redundancy_again( + self, + layer_workloads, + origin_weights, + origin_deployment, + expert_from_device, + num_node, + is_node_redundant, + rendun_pos, + ): num_experts = len(origin_weights) if is_node_redundant: num_experts = num_experts * num_node @@ -358,25 +352,33 @@ class SwiftBalanceEplb(EplbPolicy): num_redundant_experts += len(rank_empty_pos) redundant_assignments, updated_weights = self.compute_redundant_assignments( - origin_weights, num_redundant_experts, num_experts) + origin_weights, num_redundant_experts, num_experts + ) - redundant_expert_list = self.prepare_expert_list( - updated_weights, redundant_assignments, num_redundant_experts) + redundant_expert_list = self.prepare_expert_list(updated_weights, redundant_assignments, num_redundant_experts) device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information( - origin_deployment, updated_weights, rendun_pos) + origin_deployment, updated_weights, rendun_pos + ) - device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts( - layer_workloads, device_assignments, device_weights, device_loads, - device_counts, redundant_expert_list, expert_from_device, - num_experts, rendun_pos) + device_assignments, device_weights, device_loads, device_counts, com_between_devices = ( + self.distribute_redun_experts( + layer_workloads, + device_assignments, + device_weights, + device_loads, + device_counts, + redundant_expert_list, + expert_from_device, + num_experts, + rendun_pos, + ) + ) return device_assignments, device_weights, device_loads, device_counts, com_between_devices @staticmethod - def generate_allocation_report(device_assignments, device_weights, - device_loads, device_counts): - + def generate_allocation_report(device_assignments, device_weights, device_loads, device_counts): report = [] max_load = 0.0 @@ -384,27 +386,27 @@ class SwiftBalanceEplb(EplbPolicy): current_load = device_loads[dev_id] max_load = max(max_load, current_load) - report.append({ - "device_id": dev_id + 1, - "assigned_experts": device_assignments[dev_id], - "expert_weights": device_weights[dev_id], - "total_load": current_load, - "expert_count": device_counts[dev_id] - }) + report.append( + { + "device_id": dev_id + 1, + "assigned_experts": device_assignments[dev_id], + "expert_weights": device_weights[dev_id], + "total_load": current_load, + "expert_count": device_counts[dev_id], + } + ) return report, max_load @staticmethod - def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id, - next_device_id, cur_layer_result, com_between_devices): + def exchange_expert( + cur_exchange_index, next_exchange_index, cur_device_id, next_device_id, cur_layer_result, com_between_devices + ): + cur_device_deployment = cur_layer_result[cur_device_id]["assigned_experts"] + next_device_deployment = cur_layer_result[next_device_id]["assigned_experts"] - cur_device_deployment = cur_layer_result[cur_device_id][ - 'assigned_experts'] - next_device_deployment = cur_layer_result[next_device_id][ - 'assigned_experts'] - - cur_device_weight = cur_layer_result[cur_device_id]['expert_weights'] - next_device_weight = cur_layer_result[next_device_id]['expert_weights'] + cur_device_weight = cur_layer_result[cur_device_id]["expert_weights"] + next_device_weight = cur_layer_result[next_device_id]["expert_weights"] cur_expert_id = cur_device_deployment[cur_exchange_index] next_expert_id = next_device_deployment[next_exchange_index] @@ -416,29 +418,25 @@ class SwiftBalanceEplb(EplbPolicy): cur_device_weight[cur_exchange_index] = next_expert_weight next_device_weight[next_exchange_index] = cur_expert_weight - cur_layer_result[cur_device_id][ - 'total_load'] += next_expert_weight - cur_expert_weight - cur_layer_result[next_device_id][ - 'total_load'] += cur_expert_weight - next_expert_weight + cur_layer_result[cur_device_id]["total_load"] += next_expert_weight - cur_expert_weight + cur_layer_result[next_device_id]["total_load"] += cur_expert_weight - next_expert_weight com_between_devices[cur_device_id][next_device_id] = next_expert_id com_between_devices[next_device_id][cur_device_id] = cur_expert_id - def redundant_expert_deployment(self, layer_workloads, original_deployment, - expert_from_device, node_num, - is_node_redundant, rendun_pos): + def redundant_expert_deployment( + self, layer_workloads, original_deployment, expert_from_device, node_num, is_node_redundant, rendun_pos + ): device_num, per_device_expert_num = original_deployment.shape route_expert_num = layer_workloads.shape[0] per_node_device_num = self.safe_exact_divide(device_num, node_num) - per_node_route_expert_num = per_node_device_num * ( - per_device_expert_num - 1) + per_node_route_expert_num = per_node_device_num * (per_device_expert_num - 1) - weights = np.zeros((route_expert_num, ), dtype='object') + weights = np.zeros((route_expert_num,), dtype="object") for expert_id, workload_weight in enumerate(layer_workloads): weights[expert_id] = (expert_id, workload_weight) if is_node_redundant: - device_assignments = [] device_weights = [] device_loads = [] @@ -446,23 +444,30 @@ class SwiftBalanceEplb(EplbPolicy): com_between_devices = [] for node_id in range(node_num): - cur_node_weights = weights[node_id * - per_node_route_expert_num:(node_id + - 1) * - per_node_route_expert_num] + cur_node_weights = weights[ + node_id * per_node_route_expert_num : (node_id + 1) * per_node_route_expert_num + ] cur_original_deployment = original_deployment[ - node_id * per_node_device_num:(node_id + 1) * - per_node_device_num] + node_id * per_node_device_num : (node_id + 1) * per_node_device_num + ] - cur_node_rendun_pos = rendun_pos[node_id * - per_node_device_num:(node_id + - 1) * - per_node_device_num] + cur_node_rendun_pos = rendun_pos[node_id * per_node_device_num : (node_id + 1) * per_node_device_num] - cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again( - layer_workloads, cur_node_weights, cur_original_deployment, - expert_from_device, node_num, is_node_redundant, - cur_node_rendun_pos) + ( + cur_device_assignments, + cur_device_weights, + cur_device_loads, + cur_device_counts, + cur_com_between_devices, + ) = self.redundancy_again( + layer_workloads, + cur_node_weights, + cur_original_deployment, + expert_from_device, + node_num, + is_node_redundant, + cur_node_rendun_pos, + ) device_assignments += cur_device_assignments device_weights += cur_device_weights device_loads += cur_device_loads @@ -470,28 +475,41 @@ class SwiftBalanceEplb(EplbPolicy): com_between_devices += cur_com_between_devices else: - device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again( - layer_workloads, weights, original_deployment, - expert_from_device, node_num, is_node_redundant, rendun_pos) + device_assignments, device_weights, device_loads, device_counts, com_between_devices = ( + self.redundancy_again( + layer_workloads, + weights, + original_deployment, + expert_from_device, + node_num, + is_node_redundant, + rendun_pos, + ) + ) report, max_load = self.generate_allocation_report( - device_assignments, device_weights, device_loads, device_counts) + device_assignments, device_weights, device_loads, device_counts + ) return report, max_load, com_between_devices @staticmethod - def two_device_exchange_experts(cur_device_result, exchange_device_result, - cur_exchanged_expert_id, - next_exchanged_expert_id, ave_workload, - increment, num_redundancy_expert): + def two_device_exchange_experts( + cur_device_result, + exchange_device_result, + cur_exchanged_expert_id, + next_exchanged_expert_id, + ave_workload, + increment, + num_redundancy_expert, + ): + cur_device_weight = cur_device_result["expert_weights"] + next_device_weight = exchange_device_result["expert_weights"] - cur_device_weight = cur_device_result['expert_weights'] - next_device_weight = exchange_device_result['expert_weights'] + cur_device_expert_id = cur_device_result["assigned_experts"] + next_device_expert_id = exchange_device_result["assigned_experts"] - cur_device_expert_id = cur_device_result['assigned_experts'] - next_device_expert_id = exchange_device_result['assigned_experts'] - - cur_device_total_weight = cur_device_result['total_load'] - next_device_total_weight = exchange_device_result['total_load'] + cur_device_total_weight = cur_device_result["total_load"] + next_device_total_weight = exchange_device_result["total_load"] max_weight = max(cur_device_total_weight, next_device_total_weight) cur_exchange_index = -1 @@ -500,50 +518,47 @@ class SwiftBalanceEplb(EplbPolicy): for index, weight in enumerate(cur_device_weight): for next_index, next_weight in enumerate(next_device_weight): change_flag = True - if (cur_device_expert_id[index] in next_device_expert_id - or next_device_expert_id[next_index] - in cur_device_expert_id): + if ( + cur_device_expert_id[index] in next_device_expert_id + or next_device_expert_id[next_index] in cur_device_expert_id + ): change_flag = False - if (cur_device_expert_id[index] not in cur_exchanged_expert_id - ) and (next_device_expert_id[next_index] - not in next_exchanged_expert_id) and change_flag: - + if ( + (cur_device_expert_id[index] not in cur_exchanged_expert_id) + and (next_device_expert_id[next_index] not in next_exchanged_expert_id) + and change_flag + ): cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight next_total_weight_after_exchange = next_device_total_weight - next_weight + weight - exchange_max_weight = max( - cur_total_weight_after_exchange, - next_total_weight_after_exchange) - if exchange_max_weight < max_weight and ( - max_weight - - exchange_max_weight) >= (ave_workload * increment): + exchange_max_weight = max(cur_total_weight_after_exchange, next_total_weight_after_exchange) + if exchange_max_weight < max_weight and (max_weight - exchange_max_weight) >= ( + ave_workload * increment + ): max_weight = exchange_max_weight cur_exchange_index = index next_exchange_index = next_index return cur_exchange_index, next_exchange_index - def expert_exchange_between_devices(self, - ave_workload, - increment, - cur_layer_result, - com_between_devices, - num_redundancy_expert, - node_idx=0, - per_node_device_num=0, - is_node_redundant=False): - + def expert_exchange_between_devices( + self, + ave_workload, + increment, + cur_layer_result, + com_between_devices, + num_redundancy_expert, + node_idx=0, + per_node_device_num=0, + is_node_redundant=False, + ): if is_node_redundant: - cur_devices_result = cur_layer_result[node_idx * - per_node_device_num: - (node_idx + 1) * - per_node_device_num] + cur_devices_result = cur_layer_result[node_idx * per_node_device_num : (node_idx + 1) * per_node_device_num] else: cur_devices_result = cur_layer_result devices_total_weight = [] for device in cur_devices_result: - devices_total_weight.append( - (device['total_load'], device['device_id'] - 1)) + devices_total_weight.append((device["total_load"], device["device_id"] - 1)) exchange_frequency = 100 while exchange_frequency > 0: @@ -553,64 +568,81 @@ class SwiftBalanceEplb(EplbPolicy): exchange = False for index in range(0, len(devices_total_weight) - 1): min_weight_device_id = devices_total_weight[index][1] - if min_weight_device_id not in com_between_devices[ - max_weight_device_id]: - cur_exchanged_expert_id = list( - com_between_devices[max_weight_device_id].values()) - next_exchanged_expert_id = list( - com_between_devices[min_weight_device_id].values()) + if min_weight_device_id not in com_between_devices[max_weight_device_id]: + cur_exchanged_expert_id = list(com_between_devices[max_weight_device_id].values()) + next_exchanged_expert_id = list(com_between_devices[min_weight_device_id].values()) cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( cur_layer_result[max_weight_device_id], cur_layer_result[min_weight_device_id], - cur_exchanged_expert_id, next_exchanged_expert_id, - ave_workload, increment, num_redundancy_expert) + cur_exchanged_expert_id, + next_exchanged_expert_id, + ave_workload, + increment, + num_redundancy_expert, + ) if cur_exchange_index != -1: - self.exchange_expert(cur_exchange_index, - next_exchange_index, - max_weight_device_id, - min_weight_device_id, - cur_layer_result, - com_between_devices) + self.exchange_expert( + cur_exchange_index, + next_exchange_index, + max_weight_device_id, + min_weight_device_id, + cur_layer_result, + com_between_devices, + ) devices_total_weight[-1] = ( - cur_layer_result[max_weight_device_id] - ['total_load'], max_weight_device_id) + cur_layer_result[max_weight_device_id]["total_load"], + max_weight_device_id, + ) devices_total_weight[index] = ( - cur_layer_result[min_weight_device_id] - ['total_load'], min_weight_device_id) + cur_layer_result[min_weight_device_id]["total_load"], + min_weight_device_id, + ) exchange = True break if not exchange: break - def exchange_experts(self, layer_result, layer_com_between_devices, - num_nodes, device_num, is_node_redundant, - ave_workload, increment, num_redundancy_expert, - org_deployment): - + def exchange_experts( + self, + layer_result, + layer_com_between_devices, + num_nodes, + device_num, + is_node_redundant, + ave_workload, + increment, + num_redundancy_expert, + org_deployment, + ): global_deployment = [] if is_node_redundant: per_node_device_num = self.safe_exact_divide(device_num, num_nodes) for node_idx in range(num_nodes): self.expert_exchange_between_devices( - ave_workload, increment, layer_result, - layer_com_between_devices, num_redundancy_expert, node_idx, - per_node_device_num, is_node_redundant) + ave_workload, + increment, + layer_result, + layer_com_between_devices, + num_redundancy_expert, + node_idx, + per_node_device_num, + is_node_redundant, + ) else: - self.expert_exchange_between_devices(ave_workload, increment, - layer_result, - layer_com_between_devices, - num_redundancy_expert) + self.expert_exchange_between_devices( + ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert + ) max_workload = 0 for box in layer_result: - global_deployment.append(box['assigned_experts']) - if max_workload < box['total_load']: - max_workload = box['total_load'] + global_deployment.append(box["assigned_experts"]) + if max_workload < box["total_load"]: + max_workload = box["total_load"] global_deployment = np.array(global_deployment) @@ -626,16 +658,11 @@ class SwiftBalanceEplb(EplbPolicy): return count @staticmethod - def constraint_expert_local_exchange(current_expert_table, - global_deployment): + def constraint_expert_local_exchange(current_expert_table, global_deployment): for layer_id in range(len(global_deployment)): for card_id in range(len(global_deployment[layer_id])): - current_list = [ - int(x) for x in current_expert_table[layer_id][card_id] - ] - new_list = [ - int(x) for x in global_deployment[layer_id][card_id] - ] + current_list = [int(x) for x in current_expert_table[layer_id][card_id]] + new_list = [int(x) for x in global_deployment[layer_id][card_id]] num = len(new_list) new_index = [-1] * num @@ -645,8 +672,7 @@ class SwiftBalanceEplb(EplbPolicy): for i in range(num): flag = True for j in range(num): - if new_list[i] == current_list[j] and new_index[ - j] == -1: + if new_list[i] == current_list[j] and new_index[j] == -1: new_index[j] = 0 new_result[j] = current_list[j] flag = False @@ -664,25 +690,17 @@ class SwiftBalanceEplb(EplbPolicy): return global_deployment - def rebalance_experts(self, - current_expert_table, - expert_workload, - is_node_redundant=False, - increment=0.01): + def rebalance_experts(self, current_expert_table, expert_workload, is_node_redundant=False, increment=0.01): info = DynamicTable() info.workload_table = expert_workload.numpy() info.placement_table = current_expert_table.numpy() assert info.workload_table is not None layer_num, num_npus, experts_per_npu = info.workload_table.shape - expert_ids, counts = np.unique(info.placement_table[0], - return_counts=True) + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) num_redundancy_expert = self.get_redundant_num(num_npus, counts) num_original_expert = len(expert_ids) - layer_workloads = self.add_redundant(info.placement_table, - info.workload_table, - num_original_expert) - max_heat_per_layer_before = self.calculate_max_heat_per_layer( - info.workload_table, layer_num) + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) npu_heat_all_origin = sum(max_heat_per_layer_before) num_node = self.safe_exact_divide(num_npus, 8) @@ -700,14 +718,13 @@ class SwiftBalanceEplb(EplbPolicy): if num_npus < num_redundancy_expert: raise ValueError( - f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})" + "The number of NPUs " + f"({num_npus}) must be greater than or equal to the number of redundant experts " + f"({num_redundancy_expert})" ) - global_deployment: list[list[list[int]]] = [[[] - for _ in range(num_npus)] - for _ in range(layer_num)] - layer_initial_imbalance = self.calculate_initial_imbalance( - info.placement_table, layer_workloads) + global_deployment: list[list[list[int]]] = [[[] for _ in range(num_npus)] for _ in range(layer_num)] + layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads) max_heat_per_layer_after = np.zeros([layer_num]) sum_num = 0 for layer in range(layer_num): @@ -715,8 +732,7 @@ class SwiftBalanceEplb(EplbPolicy): global_deployment[layer] = info.placement_table[layer] continue - ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), - num_npus) + ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), num_npus) rendun_pos: list[list[int]] = [[] for _ in range(num_npus)] existing_experts = set() @@ -729,30 +745,37 @@ class SwiftBalanceEplb(EplbPolicy): rendun_pos[device_id].append(index) result, max_workload, com_between_devices = self.redundant_expert_deployment( - layer_workloads[layer], info.placement_table[layer], - expert_from_device[layer], num_node, is_node_redundant, - rendun_pos) + layer_workloads[layer], + info.placement_table[layer], + expert_from_device[layer], + num_node, + is_node_redundant, + rendun_pos, + ) global_deployment[layer], new_max_workload = self.exchange_experts( - result, com_between_devices, num_node, num_npus, - is_node_redundant, ave_workload, increment, - num_redundancy_expert, info.placement_table[layer]) + result, + com_between_devices, + num_node, + num_npus, + is_node_redundant, + ave_workload, + increment, + num_redundancy_expert, + info.placement_table[layer], + ) for device_id in range(num_npus): - com_between_devices[device_id] = { - key: value - for key, value in com_between_devices[device_id].items() - } + com_between_devices[device_id] = {key: value for key, value in com_between_devices[device_id].items()} sum_num += self.count_elements(com_between_devices[device_id]) - max_heat_per_layer_after[layer] = max( - result, key=lambda x: x['total_load'])['total_load'] + max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_load"])["total_load"] layer_changed_ratio = [] for layer_idx in range(layer_num): layer_changed_ratio.append( - self.safe_divide(max_heat_per_layer_after[layer_idx], - max_heat_per_layer_before[layer_idx])) + self.safe_divide(max_heat_per_layer_after[layer_idx], max_heat_per_layer_before[layer_idx]) + ) per_layer_priority = np.argsort(layer_changed_ratio) npu_heat_all_after = sum(max_heat_per_layer_after) @@ -761,8 +784,6 @@ class SwiftBalanceEplb(EplbPolicy): if npu_heat_all_after < 0.95 * npu_heat_all_origin: change = 1 - new_global_deployment = self.constraint_expert_local_exchange( - current_expert_table, global_deployment) + new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment) - return change, per_layer_priority, np.array( - new_global_deployment).tolist() + return change, per_layer_priority, np.array(new_global_deployment).tolist() diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index cf7cece4..b9cd66a8 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -26,9 +26,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess class EplbUpdator: - - def __init__(self, eplb_config, loader, eplb_process: EplbProcess, - process): + def __init__(self, eplb_config, loader, eplb_process: EplbProcess, process): self.eplb_config = eplb_config self.init_eplb(self.eplb_config.expert_map_path, process) self.eplb_loader = loader @@ -43,9 +41,7 @@ class EplbUpdator: self.world_size = dist.get_world_size() self.device = local_load.device shape = (self.world_size, *local_load.shape) - self._gather_buffer = torch.empty(shape, - dtype=local_load.dtype, - device=self.device) + self._gather_buffer = torch.empty(shape, dtype=local_load.dtype, device=self.device) def init_eplb(self, expert_map_path, process): self.rank_id = dist.get_rank() @@ -72,52 +68,49 @@ class EplbUpdator: self.process = process - logger.info( - f"[ModelRunner] Launched EPLB process (pid={self.process.pid})") + logger.info(f"[ModelRunner] Launched EPLB process (pid={self.process.pid})") def update_iteration(self): self.cur_iterations += 1 - if self.cur_iterations == (self.expert_heat_collection_interval + \ - self.algorithm_execution_interval + self.num_moe_layers): + if self.cur_iterations == ( + self.expert_heat_collection_interval + self.algorithm_execution_interval + self.num_moe_layers + ): logger.info("Finish expert parallel load balancing.") if self.expert_map_record_path is not None: - self.adaptor._export_tensor_to_file( - self.shared_dict["expert_maps"], - self.expert_map_record_path) + self.adaptor._export_tensor_to_file(self.shared_dict["expert_maps"], self.expert_map_record_path) self.adaptor.model.clear_all_moe_loads() self.cur_iterations = 0 def get_update_info_flag(self): - return self.cur_iterations == (self.expert_heat_collection_interval + - self.algorithm_execution_interval - 1) + return self.cur_iterations == (self.expert_heat_collection_interval + self.algorithm_execution_interval - 1) def wakeup_eplb_worker_flag(self): - return self.cur_iterations == (self.expert_heat_collection_interval - - 1) + return self.cur_iterations == (self.expert_heat_collection_interval - 1) def update_expert_weight_flag(self): weight_update_counter = self.cur_iterations - ( - self.expert_heat_collection_interval + - self.algorithm_execution_interval) - return (weight_update_counter >= 0 - and weight_update_counter < self.num_moe_layers) + self.expert_heat_collection_interval + self.algorithm_execution_interval + ) + return weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers def wakeup_eplb_worker(self): self.eplb_process.planner_q.put(1) def forward_before(self): if self.update_expert_weight_flag(): - (expert_send_info, expert_recv_info, updated_expert_map, - log2phy_map, layer_id) = self.update_info_all.pop(0) + (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop( + 0 + ) log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map)) self.eplb_loader.set_log2phy_map(log2phy_map_this_rank) - updated_expert_map_this_rank = torch.from_numpy( - numpy.array(updated_expert_map)) + updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map)) self.eplb_loader.generate_expert_d2d_transfer_task( - expert_send_info, expert_recv_info, + expert_send_info, + expert_recv_info, updated_expert_map_this_rank, - layer_id + self.adaptor.num_dense_layers) + layer_id + self.adaptor.num_dense_layers, + ) # set asynchronous stream for d2d expert weight update self.reqs = [] @@ -133,8 +126,7 @@ class EplbUpdator: self.compute_and_set_moe_load() self.wakeup_eplb_worker() - if self.update_expert_weight_flag( - ) and self.expert_map_record_path is None: + if self.update_expert_weight_flag() and self.expert_map_record_path is None: self.eplb_loader.update_expert_map_and_weight(self.reqs) self.update_iteration() @@ -145,9 +137,7 @@ class EplbUpdator: moe_load = self._gather_buffer.permute(1, 0, 2) self.shared_dict["moe_load"] = moe_load.cpu() - logger.debug( - f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" - ) + logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") if dist.get_rank() == 0: self.compute_moe_imbalance(moe_load) @@ -156,7 +146,6 @@ class EplbUpdator: return moe_load def compute_moe_imbalance(self, moe_load: torch.Tensor): - self.moe_imbalance_dict.clear() layer_card_load = moe_load.sum(dim=-1).cpu().float() @@ -169,13 +158,11 @@ class EplbUpdator: moe_load_imbalance = max_load / (mean_load + 1e-6) - logger.debug(f"[ModelRunner][MOE_load_stats][Layer {layer_idx}] " - f"PAR={moe_load_imbalance:.4f}") + logger.debug(f"[ModelRunner][MOE_load_stats][Layer {layer_idx}] PAR={moe_load_imbalance:.4f}") self.moe_imbalance_dict[layer_idx] = moe_load_imbalance def summarize_moe_imbalance(self): - values = list(self.moe_imbalance_dict.values()) if not values: logger.info("[MOE_load_stats] No data available.") @@ -191,11 +178,10 @@ class EplbUpdator: ) def warm_up_eplb(self): - self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map() self.compute_and_set_moe_load() - src_tensor = torch.empty((1, ), device=self.device) + src_tensor = torch.empty((1,), device=self.device) self_rank = dist.get_rank() comm_op_list = [] diff --git a/vllm_ascend/eplb/utils.py b/vllm_ascend/eplb/utils.py index 0efa623b..a6a577b3 100644 --- a/vllm_ascend/eplb/utils.py +++ b/vllm_ascend/eplb/utils.py @@ -30,33 +30,30 @@ def get_log2phy_map(self, layer_id): def get_all_expert_map(self, num_moe_layers): all_loads = [] - num_dense_layers = self.num_dense_layers if hasattr( - self, "num_dense_layers") else 0 + num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0 for layer_id in range(num_moe_layers): - load_tensor = self.get_expert_map( - layer_id + num_dense_layers) # (num_experts_per_layer,) + load_tensor = self.get_expert_map(layer_id + num_dense_layers) # (num_experts_per_layer,) all_loads.append(load_tensor) return torch.stack(all_loads, dim=0) def get_all_moe_loads(self): - num_dense_layers = self.num_dense_layers if hasattr( - self, "num_dense_layers") else 0 + num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0 all_moe_loads = torch.stack( - [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \ - for layer_id in range(self.num_moe_layers)], - dim=0 + [ + self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load + for layer_id in range(self.num_moe_layers) + ], + dim=0, ) return all_moe_loads def clear_all_moe_loads(self): - num_dense_layers = self.num_dense_layers if hasattr( - self, "num_dense_layers") else 0 + num_dense_layers = self.num_dense_layers if hasattr(self, "num_dense_layers") else 0 for layer_id in range(self.num_moe_layers): - self.model.layers[layer_id + - num_dense_layers].mlp.experts.clear_moe_load() + self.model.layers[layer_id + num_dense_layers].mlp.experts.clear_moe_load() def model_register(model, model_config): diff --git a/vllm_ascend/model_loader/netloader/executor/elastic_load.py b/vllm_ascend/model_loader/netloader/executor/elastic_load.py index 476116de..8432f059 100644 --- a/vllm_ascend/model_loader/netloader/executor/elastic_load.py +++ b/vllm_ascend/model_loader/netloader/executor/elastic_load.py @@ -18,8 +18,7 @@ import torch import torch_npu from vllm.logger import logger -from .netloader_pg import (destroy_stateless_process_group, - stateless_init_process_group) +from .netloader_pg import destroy_stateless_process_group, stateless_init_process_group class P2PLoad: @@ -56,9 +55,7 @@ class P2PLoad: - The model if loading is successful, otherwise None. """ model_device = next(model.parameters()).device - logger.info( - f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" - ) + logger.info(f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}") receiver_pg = None loaded_model = None try: @@ -67,15 +64,13 @@ class P2PLoad: port=self.source_port, rank=0, world_size=2, - group_name='netloader', + group_name="netloader", ) logger.info( f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" ) - logger.info( - f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" - ) + logger.info(f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}") logger.info(f"Model device: {model_device}") trans_stream = torch_npu.npu.Stream() @@ -84,14 +79,11 @@ class P2PLoad: if len(param.shape) == 0: continue receiver_pg.recv([param], 1, 0).wait() - torch.distributed.barrier(group=receiver_pg, - device_ids=[model_device.index]) + torch.distributed.barrier(group=receiver_pg, device_ids=[model_device.index]) torch_npu.npu.synchronize(trans_stream) - logger.info( - f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" - ) + logger.info(f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}") loaded_model = model except Exception as e: logger.error("Failed to recv model: {}".format(e)) @@ -129,9 +121,7 @@ class P2PSend: """ model_device = next(model.parameters()).device torch.npu.set_device(model_device) - logger.info( - f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" - ) + logger.info(f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}") sender_pg = None try: sender_pg = stateless_init_process_group( @@ -139,14 +129,10 @@ class P2PSend: port=self.listen_port, rank=1, world_size=2, - group_name='netloader', - ) - logger.info( - f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" - ) - logger.info( - f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" + group_name="netloader", ) + logger.info(f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}") + logger.info(f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}") logger.info(f"Model device: {model_device}") trans_stream = torch_npu.npu.Stream() @@ -155,16 +141,12 @@ class P2PSend: if "aclnn_input_scale" in name: continue if name in int8_params: - sender_pg.send([int8_params[name].to(model_device)], 0, - 0).wait() + sender_pg.send([int8_params[name].to(model_device)], 0, 0).wait() else: sender_pg.send([param.contiguous()], 0, 0).wait() - torch.distributed.barrier(group=sender_pg, - device_ids=[model_device.index]) + torch.distributed.barrier(group=sender_pg, device_ids=[model_device.index]) torch_npu.npu.synchronize(trans_stream) - logger.info( - f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" - ) + logger.info(f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}") finally: if sender_pg: - destroy_stateless_process_group(sender_pg) \ No newline at end of file + destroy_stateless_process_group(sender_pg) diff --git a/vllm_ascend/model_loader/netloader/executor/netloader_pg.py b/vllm_ascend/model_loader/netloader/executor/netloader_pg.py index 13018a50..1a6ce330 100644 --- a/vllm_ascend/model_loader/netloader/executor/netloader_pg.py +++ b/vllm_ascend/model_loader/netloader/executor/netloader_pg.py @@ -17,16 +17,13 @@ import gc import ipaddress from datetime import timedelta -from typing import Any, Optional +from typing import Any import torch import torch_npu -from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT, - _register_process_group, - _unregister_process_group) +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT, _register_process_group, _unregister_process_group from torch.distributed import ProcessGroup, is_hccl_available -from torch.distributed.distributed_c10d import (Backend, BackendConfig, - PrefixStore, _world) +from torch.distributed.distributed_c10d import Backend, BackendConfig, PrefixStore, _world from torch.distributed.rendezvous import rendezvous from torch_npu._C._distributed_c10d import ProcessGroupHCCL from vllm.logger import logger @@ -39,7 +36,7 @@ def stateless_init_process_group( rank: int, timeout: timedelta = _DEFAULT_PG_TIMEOUT, group_name: str = "", - pg_options: Optional[Any] = None, + pg_options: Any | None = None, ) -> ProcessGroup: """ Initializes a stateless process group. @@ -57,7 +54,8 @@ def stateless_init_process_group( ProcessGroup: The initialized process group. Raises: - RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable. + RuntimeError: If world_size is not positive, or if rank is not within + [0, world_size - 1], or if HCCL is unavailable. TypeError: If timeout is not a timedelta type. ValueError: If group_name already exists. """ @@ -67,21 +65,18 @@ def stateless_init_process_group( raise RuntimeError("world_size must be positive") # Check if rank is within [0, world_size - 1] if not (rank >= 0 and rank <= world_size - 1): - raise RuntimeError( - "rank should be a number between 0 and ``world_size``-1") + raise RuntimeError("rank should be a number between 0 and ``world_size``-1") # Check if HCCL is available if not is_hccl_available(): raise RuntimeError("HCCL is not available") # Check if timeout is a timedelta type if not isinstance(timeout, timedelta): - raise TypeError( - f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" - ) + raise TypeError(f"Expected timeout argument to be of type datetime.timedelta, got {timeout}") # Check if group_name already exists if group_name in _world.pg_names.values(): raise ValueError( - f"The specified group name {group_name} has already been " - "created, please use a different group name") + f"The specified group name {group_name} has already been created, please use a different group name" + ) # Function to check if an IPv6 address is valid def is_valid_ipv6_address(address: str) -> bool: @@ -101,10 +96,9 @@ def stateless_init_process_group( # Get initialization method init_method = get_tcp_uri(host, port) # Create Backend object - backend = Backend('hccl') + backend = Backend("hccl") # Use rendezvous function to get store, rank, and world_size - store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout)) + store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout)) # Set timeout for store store.set_timeout(timeout) @@ -125,9 +119,7 @@ def stateless_init_process_group( pg._set_default_backend(Backend.backend_type_map[backend]) # Check if pg_options is None or not of type ProcessGroupHCCL.Options - if pg_options is None or not isinstance( - pg_options, - torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options): + if pg_options is None or not isinstance(pg_options, torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options): pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options() # Set attributes for pg_options pg_options.is_high_priority_stream = False @@ -135,8 +127,7 @@ def stateless_init_process_group( pg_options.global_ranks_in_group = [] pg_options.group_id = f"{init_method}/{group_name}/" # Create ProcessGroupHCCL object - backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, - pg_options) + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, pg_options) # Set sequence number for backend_class backend_class._set_sequence_number_for_group() # Set backend_type @@ -176,9 +167,10 @@ def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False): _world.pg_group_ranks.pop(pg, None) _world.pg_backend_config.pop(pg, None) # Check if pg is in keys of _world.pg_coalesce_state - if pg in _world.pg_coalesce_state.keys(): - logger.warning("Some coalesced collectives haven't been launched when " - "ProcessGroup is destroyed. They will be cleaned.") + if pg in _world.pg_coalesce_state: + logger.warning( + "Some coalesced collectives haven't been launched when ProcessGroup is destroyed. They will be cleaned." + ) del _world.pg_coalesce_state[pg] # Unregister the process group _unregister_process_group(pg.group_name) diff --git a/vllm_ascend/model_loader/netloader/interaction/elastic.py b/vllm_ascend/model_loader/netloader/interaction/elastic.py index 61b2ad3b..7c1c7f8d 100644 --- a/vllm_ascend/model_loader/netloader/interaction/elastic.py +++ b/vllm_ascend/model_loader/netloader/interaction/elastic.py @@ -18,7 +18,7 @@ import json import re import socket import threading -from typing import List, Optional, Tuple +from contextlib import suppress import torch from vllm.logger import logger @@ -32,8 +32,7 @@ class ElasticClient: Class for handling the client-side logic of Netloader of models. """ - def __init__(self, sources: list[str], device_id: int, model_path: str, - tp: int, pp: int): + def __init__(self, sources: list[str], device_id: int, model_path: str, tp: int, pp: int): """ Initializes the ElasticClient instance. @@ -50,14 +49,14 @@ class ElasticClient: self.tp = tp self.pp = pp - self.s: Optional[socket.socket] = None - self.ack: Optional[Tuple[str, int]] = None - self.server_addr: Optional[str] = None - self.server_port: Optional[int] = None + self.s: socket.socket | None = None + self.ack: tuple[str, int] | None = None + self.server_addr: str | None = None + self.server_port: int | None = None for source in self.sources: try: - ip, port_str = source.split(':') + ip, port_str = source.split(":") port = int(port_str) except Exception as e: logger.info(f"IP format error: {source}, detail: {e}") @@ -68,13 +67,9 @@ class ElasticClient: try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - logger.info( - f"Start connection to server: {self.server_addr}:{self.server_port}" - ) + logger.info(f"Start connection to server: {self.server_addr}:{self.server_port}") sock.connect((self.server_addr, self.server_port)) - logger.info( - f"Finish connection to server: {self.server_addr}:{self.server_port}" - ) + logger.info(f"Finish connection to server: {self.server_addr}:{self.server_port}") sock.settimeout(60) self.s = sock @@ -83,10 +78,8 @@ class ElasticClient: except Exception as e: logger.error(f"Connect to {source} fails, detail: {e}") if sock is not None: - try: + with suppress(Exception): sock.close() - except Exception: - pass self.s = None self.ack = None self.server_addr = None @@ -120,10 +113,8 @@ class ElasticClient: """ Destructor method to ensure socket is closed. """ - try: + with suppress(Exception): self.close() - except Exception: - pass def send_str(self, data_str: str) -> None: """ @@ -151,8 +142,7 @@ class ElasticClient: data_str = self.s.recv(buffer_size).decode("utf-8") return data_str - def register(self, device_id: int, model_path: str, tp: int, - pp: int) -> Tuple[str, int]: + def register(self, device_id: int, model_path: str, tp: int, pp: int) -> tuple[str, int]: """ Registers the client with the server. @@ -168,20 +158,13 @@ class ElasticClient: free_port = find_free_port() data = { "label": "JOIN", - "content": { - 'device_id': device_id, - 'model_path': model_path, - 'tp': tp, - 'pp': pp, - 'port': free_port - } + "content": {"device_id": device_id, "model_path": model_path, "tp": tp, "pp": pp, "port": free_port}, } try: self.send_str(json.dumps(data)) except Exception as e: - raise RuntimeError( - f"Send data {data} to server fails, detail: {e}") + raise RuntimeError(f"Send data {data} to server fails, detail: {e}") try: ack_str = self.recv_str() @@ -191,23 +174,22 @@ class ElasticClient: try: ack = json.loads(ack_str) except Exception as e: - raise RuntimeError( - f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}" - ) + raise RuntimeError(f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}") logger.info(f"Receive ack: {ack}") - if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack - and ack["content"] is not None and "name" in ack["content"]): + if ( + "label" in ack + and ack["label"] == "JOIN_ACK" + and "content" in ack + and ack["content"] is not None + and "name" in ack["content"] + ): return (ack["content"]["name"], free_port) - elif ("label" in ack and ack["label"] == 'JOIN_NACK' - and "content" in ack): - raise RuntimeError( - f"Receive nack from server, reason: {ack['content']}") + elif "label" in ack and ack["label"] == "JOIN_NACK" and "content" in ack: + raise RuntimeError(f"Receive nack from server, reason: {ack['content']}") else: - raise RuntimeError( - f"Receive ack {ack} from server does not contain required fields" - ) + raise RuntimeError(f"Receive ack {ack} from server does not contain required fields") class ElasticServer: @@ -215,9 +197,18 @@ class ElasticServer: Class for handling the server-side logic of Netloader of models. """ - def __init__(self, addr: str, port: int, model, device_id: int, - model_path: str, tp: int, pp: int, int8_cache: str, - int8_cache_name: Optional[List[str]]): + def __init__( + self, + addr: str, + port: int, + model, + device_id: int, + model_path: str, + tp: int, + pp: int, + int8_cache: str, + int8_cache_name: list[str] | None, + ): """ Initializes the ElasticServer instance. @@ -246,30 +237,25 @@ class ElasticServer: self.pp = pp self.original_int8 = {} - int8_pattern = "|".join( - map(re.escape, - int8_cache_name)) if int8_cache_name is not None else "(?:)" + int8_pattern = "|".join(map(re.escape, int8_cache_name)) if int8_cache_name is not None else "(?:)" for name, param in self.model.named_parameters(): if param.dtype == torch.int8: - if int8_cache == 'hbm': + if int8_cache == "hbm": if int8_cache_name is None or ( - int8_cache_name is not None - and re.search(int8_pattern, name) is not None): + int8_cache_name is not None and re.search(int8_pattern, name) is not None + ): try: - self.original_int8[name] = param.data.clone( - ).detach() + self.original_int8[name] = param.data.clone().detach() except RuntimeError as e: - logger.error( - f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}" - ) + logger.error(f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}") self.original_int8[name] = param.data.cpu() - elif int8_cache == 'dram': + elif int8_cache == "dram": if int8_cache_name is None or ( - int8_cache_name is not None - and re.search(int8_pattern, name) is not None): + int8_cache_name is not None and re.search(int8_pattern, name) is not None + ): self.original_int8[name] = param.data.cpu() - elif int8_cache == 'no': + elif int8_cache == "no": pass else: logger.warning( @@ -277,14 +263,18 @@ class ElasticServer: ) logger.info( - f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}" + f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, " + f"model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, " + f"int8 params {list(self.original_int8)} are saved to {int8_cache}" ) def __del__(self): """ Destructor method to ensure socket is closed. """ - self.s.close() + if self.s is not None: + with suppress(Exception): + self.s.close() def start(self): """ @@ -343,10 +333,7 @@ class ElasticServer: if not all(k in content for k in required_keys): return False port = content["port"] - if not (isinstance(port, int) or - (isinstance(port, str) and port.isdigit())): - return False - return True + return isinstance(port, int) or (isinstance(port, str) and port.isdigit()) comm_name = None if is_valid_data(data): @@ -355,36 +342,31 @@ class ElasticServer: tp = int(data["content"]["tp"]) pp = int(data["content"]["pp"]) - if int(self.device_id - ) == device_id and self.model_path == model_path and int( - self.tp) == tp and int(self.pp) == pp: + if ( + int(self.device_id) == device_id + and self.model_path == model_path + and int(self.tp) == tp + and int(self.pp) == pp + ): comm_name = str(addr[0]) + ":" + str(addr[1]) ack = {"label": "JOIN_ACK", "content": {"name": comm_name}} else: - logger.warning( - f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}" - ) + server_desc = (int(self.device_id), self.model_path, int(self.tp), int(self.pp)) + client_desc = (device_id, model_path, tp, pp) + msg = f"Received data {client_desc} does not consist with this server {server_desc}" + logger.warning(msg) ack = { - "label": - "JOIN_NACK", - "content": - f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}" + "label": "JOIN_NACK", + "content": msg, } else: - logger.warning( - f"Received data does not contain required fields: {data}") - ack = { - "label": - "JOIN_NACK", - "content": - f"Received data does not contain required fields: {data}" - } + logger.warning(f"Received data does not contain required fields: {data}") + ack = {"label": "JOIN_NACK", "content": f"Received data does not contain required fields: {data}"} try: ack_str = json.dumps(ack).encode("utf-8") except Exception as e: - logger.error( - f"Failed to convert {ack} to JSON format, details: {e}") + logger.error(f"Failed to convert {ack} to JSON format, details: {e}") conn.close() return @@ -395,14 +377,10 @@ class ElasticServer: conn.close() return - if ack["content"] and isinstance(ack["content"], - dict) and 'name' in ack["content"]: + if ack["content"] and isinstance(ack["content"], dict) and "name" in ack["content"]: try: - p2psend = P2PSend(self.addr, data["content"]["port"], - ack["content"]["name"]) + p2psend = P2PSend(self.addr, data["content"]["port"], ack["content"]["name"]) p2psend.send(self.model, self.original_int8) except Exception as e: - logger.error( - f"P2PSend Failed to send model to {self.addr}, details: {e}" - ) + logger.error(f"P2PSend Failed to send model to {self.addr}, details: {e}") conn.close() diff --git a/vllm_ascend/model_loader/netloader/load.py b/vllm_ascend/model_loader/netloader/load.py index 4dd24107..eafa97e5 100644 --- a/vllm_ascend/model_loader/netloader/load.py +++ b/vllm_ascend/model_loader/netloader/load.py @@ -48,36 +48,27 @@ def elastic_load( # Filter sources for the current device sources_this_device = [] for s in sources: - if isinstance( - s, dict - ) and "device_id" in s and s["device_id"] == device_id and isinstance( - s["sources"], list): + if isinstance(s, dict) and "device_id" in s and s["device_id"] == device_id and isinstance(s["sources"], list): sources_this_device += s["sources"] if len(sources_this_device) == 0: return None try: # Initialize the interaction layer with the ElasticClient - with ElasticClient(sources_this_device, device_id, model_path, tp, - pp) as client_interaction_layer: + with ElasticClient(sources_this_device, device_id, model_path, tp, pp) as client_interaction_layer: if client_interaction_layer.s is None or client_interaction_layer.server_addr is None: - raise RuntimeError( - "Failed to initialize ElasticClient: socket or server_addr is None" - ) + raise RuntimeError("Failed to initialize ElasticClient: socket or server_addr is None") ack = client_interaction_layer.ack if ack is None: raise RuntimeError("ElasticClient.register did not return ack") t0 = time.perf_counter() - elastic_loader = P2PLoad(ack[0], - client_interaction_layer.server_addr, - ack[1]) + elastic_loader = P2PLoad(ack[0], client_interaction_layer.server_addr, ack[1]) model_loaded = elastic_loader.load(model=model) if model_loaded is None: logger.error("Failed to load model") return None - logger.info("Finish elastic load (duration: {}s)".format( - time.perf_counter() - t0)) + logger.info("Finish elastic load (duration: {}s)".format(time.perf_counter() - t0)) return model_loaded except Exception as e: logger.info(f"elastic_load error: {e}") diff --git a/vllm_ascend/model_loader/netloader/netloader.py b/vllm_ascend/model_loader/netloader/netloader.py index 2968ee36..cc22ea33 100644 --- a/vllm_ascend/model_loader/netloader/netloader.py +++ b/vllm_ascend/model_loader/netloader/netloader.py @@ -18,7 +18,6 @@ import gc import json import time from copy import deepcopy -from typing import List, Optional, Tuple import torch from torch import nn @@ -27,8 +26,7 @@ from vllm.logger import logger from vllm.model_executor.model_loader import register_model_loader from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.default_loader import DefaultModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading) +from vllm.model_executor.model_loader.utils import initialize_model, process_weights_after_loading from vllm.utils.torch_utils import set_default_torch_dtype from .interaction.elastic import ElasticServer @@ -41,12 +39,13 @@ class ModelNetLoaderElastic(BaseModelLoader): """ A model loader that uses elastic loading for loading weights. """ - source: Optional[List[dict]] - model_path: Optional[str] - listen_port: Optional[int] + + source: list[dict] | None + model_path: str | None + listen_port: int | None int8_cache: str - int8_cache_name: Optional[List[str]] - output_prefix: Optional[str] + int8_cache_name: list[str] | None + output_prefix: str | None def __init__(self, load_config: LoadConfig): """ @@ -63,18 +62,15 @@ class ModelNetLoaderElastic(BaseModelLoader): extra = load_config.model_loader_extra_config if extra and "CONFIG_FILE" in extra: try: - logger.info( - f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..." - ) - with open(extra["CONFIG_FILE"], 'r') as f: + logger.info(f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ...") + with open(extra["CONFIG_FILE"]) as f: config = json.load(f) except FileNotFoundError: logger.error("CONFIG_FILE not found") except json.JSONDecodeError: logger.error("CONFIG_FILE is not a valid JSON file") except Exception as e: - logger.error( - f"Unexpected error while reading CONFIG_FILE: {e}") + logger.error(f"Unexpected error while reading CONFIG_FILE: {e}") if config is None and extra: logger.info("Reading configs in model_loader_extra_config ...") @@ -82,19 +78,30 @@ class ModelNetLoaderElastic(BaseModelLoader): config = config or {} for key, attr, checker, caster, default in [ - ("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, - None), - ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, - None), - ("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or - (isinstance(v, str) and v.isdigit()), lambda v: int(v), None), - ("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v. - lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'), - ("INT8_CACHE_NAME", "int8_cache_name", - lambda v: isinstance(v, list), lambda v: v, None), - ("OUTPUT_PREFIX", "output_prefix", - lambda v: isinstance(v, str) and is_valid_path_prefix(v), - lambda v: v, None), + ("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, None), + ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, None), + ( + "LISTEN_PORT", + "listen_port", + lambda v: isinstance(v, int) or (isinstance(v, str) and v.isdigit()), + lambda v: int(v), + None, + ), + ( + "INT8_CACHE", + "int8_cache", + lambda v: isinstance(v, str) and v.lower() in ["hbm", "dram", "no"], + lambda v: v.lower(), + "no", + ), + ("INT8_CACHE_NAME", "int8_cache_name", lambda v: isinstance(v, list), lambda v: v, None), + ( + "OUTPUT_PREFIX", + "output_prefix", + lambda v: isinstance(v, str) and is_valid_path_prefix(v), + lambda v: v, + None, + ), ]: v = config.get(key, default) if not checker(v): @@ -116,8 +123,7 @@ class ModelNetLoaderElastic(BaseModelLoader): self.output_prefix, ) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module: """ Loads the model using the specified configuration. @@ -140,15 +146,18 @@ class ModelNetLoaderElastic(BaseModelLoader): device_id = torch.distributed.get_rank() - if (self.source is None or not isinstance(self.source, list) - or device_id not in [ - one_device["device_id"] for one_device in self.source if - isinstance(one_device, dict) and "device_id" in one_device - ]): - logger.warning( - "Did not get valid source info, use DefaultModelLoader") - model, need_process_weights_after_loading = self.revert_to_default( - model_config, vllm_config, device_config) + if ( + self.source is None + or not isinstance(self.source, list) + or device_id + not in [ + one_device["device_id"] + for one_device in self.source + if isinstance(one_device, dict) and "device_id" in one_device + ] + ): + logger.warning("Did not get valid source info, use DefaultModelLoader") + model, need_process_weights_after_loading = self.revert_to_default(model_config, vllm_config, device_config) else: target_device = torch.device(device_config.device) @@ -158,8 +167,7 @@ class ModelNetLoaderElastic(BaseModelLoader): with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config, - model_config=model_config) + model = initialize_model(vllm_config=vllm_config, model_config=model_config) start_elastic_load = time.perf_counter() model = elastic_load( @@ -171,43 +179,39 @@ class ModelNetLoaderElastic(BaseModelLoader): pp=parallel_config.pipeline_parallel_size, ) end_elastic_load = time.perf_counter() - logger.info( - f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}" - ) + logger.info(f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}") need_process_weights_after_loading = True if model is None: - logger.warning( - "Netloader elastic loading fails, use load format DefaultModelLoader" - ) + logger.warning("Netloader elastic loading fails, use load format DefaultModelLoader") vllm_config = vllm_config_backup model_config = model_config_backup del model gc.collect() - if device_config.device_type == 'npu': + if device_config.device_type == "npu": logger.info("Empty NPU cache") torch.npu.empty_cache() - elif device_config.device_type == 'cuda': + elif device_config.device_type == "cuda": logger.info("Empty CUDA cache") torch.cuda.empty_cache() model, need_process_weights_after_loading = self.revert_to_default( - model_config, vllm_config, device_config) + model_config, vllm_config, device_config + ) start_elastic_server = time.perf_counter() # start elastic server if model is not None and ( - (self.listen_port and self.listen_port in range(1024, 65535)) or - (self.listen_port is None)): - + (self.listen_port and self.listen_port in range(1024, 65535)) or (self.listen_port is None) + ): from vllm.utils.network_utils import get_ip + driver_ip = get_ip() - if driver_ip == '0.0.0.0': - logger.error( - "Driver IP is not set, skip to start Netloader server") + if driver_ip == "0.0.0.0": + logger.error("Driver IP is not set, skip to start Netloader server") else: if self.listen_port is None: self.listen_port = find_free_port() @@ -220,21 +224,14 @@ class ModelNetLoaderElastic(BaseModelLoader): if self.output_prefix is not None: try: - with open(self.output_prefix + str(device_id) + '.txt', - 'w') as file: + with open(self.output_prefix + str(device_id) + ".txt", "w") as file: file.write(f"{driver_ip}:{self.listen_port}") - logger.info( - f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}" - ) + logger.info(f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}") except FileNotFoundError: - logger.error( - f"File path {self.output_prefix + str(device_id)} does not exist." - ) + logger.error(f"File path {self.output_prefix + str(device_id)} does not exist.") except PermissionError: - logger.error( - f"No permission to write to file {self.output_prefix + str(device_id)}." - ) - except IOError as e: + logger.error(f"No permission to write to file {self.output_prefix + str(device_id)}.") + except OSError as e: logger.error( f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}" ) @@ -242,31 +239,30 @@ class ModelNetLoaderElastic(BaseModelLoader): logger.error(f"Unknown error: {e}") try: - assert isinstance( - self.listen_port, int - ), f"listen port should be int but get {self.listen_port}" + assert isinstance(self.listen_port, int), f"listen port should be int but get {self.listen_port}" elastic_server = ElasticServer( - driver_ip, self.listen_port, model, device_id, - self.model_path, parallel_config.tensor_parallel_size, + driver_ip, + self.listen_port, + model, + device_id, + self.model_path, + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, - self.int8_cache, self.int8_cache_name) + self.int8_cache, + self.int8_cache_name, + ) elastic_server.start() except Exception as e: - logger.error( - f"Failed to start Netloader server for rank: {device_id}, details: {e}" - ) + logger.error(f"Failed to start Netloader server for rank: {device_id}, details: {e}") else: logger.info("Skip to start Netloader server") end_elastic_server = time.perf_counter() - logger.info( - f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}" - ) + logger.info(f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}") if need_process_weights_after_loading: - process_weights_after_loading(model, model_config, - torch.device(device_config.device)) + process_weights_after_loading(model, model_config, torch.device(device_config.device)) if model is None: logger.error("NetLoader elastic loads model fails") @@ -274,8 +270,7 @@ class ModelNetLoaderElastic(BaseModelLoader): return model.eval() - def revert_to_default(self, model_config, vllm_config, - device_config) -> Tuple[nn.Module, bool]: + def revert_to_default(self, model_config, vllm_config, device_config) -> tuple[nn.Module, bool]: """ Reverts to the default model loading logic when elastic loading fails or is not applicable. @@ -300,19 +295,15 @@ class ModelNetLoaderElastic(BaseModelLoader): default_model_loader = DefaultModelLoader(self.load_config) if model_config.quantization is None: - model = default_model_loader.load_model(vllm_config=vllm_config, - model_config=model_config) + model = default_model_loader.load_model(vllm_config=vllm_config, model_config=model_config) need_process_weights_after_loading = False else: - logger.warning( - "Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading " - ) + logger.warning("Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading ") need_process_weights_after_loading = True target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config, - model_config=model_config) + model = initialize_model(vllm_config=vllm_config, model_config=model_config) default_model_loader.load_weights(model, model_config) model = model.eval() @@ -321,6 +312,5 @@ class ModelNetLoaderElastic(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: pass - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: pass diff --git a/vllm_ascend/model_loader/netloader/utils.py b/vllm_ascend/model_loader/netloader/utils.py index fba5a58c..1481ed16 100644 --- a/vllm_ascend/model_loader/netloader/utils.py +++ b/vllm_ascend/model_loader/netloader/utils.py @@ -29,7 +29,7 @@ def find_free_port(): - A free port number. """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] @@ -47,20 +47,17 @@ def is_valid_path_prefix(path_prefix): return False if re.search(r'[<>:"|?*]', path_prefix): - logger.warning( - f'The path prefix {path_prefix} contains illegal characters.') + logger.warning(f"The path prefix {path_prefix} contains illegal characters.") return False - if path_prefix.startswith('/') or path_prefix.startswith('\\'): + if path_prefix.startswith("/") or path_prefix.startswith("\\"): if not os.path.exists(os.path.dirname(path_prefix)): - logger.warning( - f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.' - ) + logger.warning(f"The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.") return False else: if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))): logger.warning( - f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.' + f"The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist." ) return False - return True \ No newline at end of file + return True diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index c215c059..5fa56ce8 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -23,9 +23,8 @@ import vllm_ascend.patch.platform.patch_sched_yield # noqa from vllm_ascend import envs from vllm_ascend.utils import vllm_version_is -if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv( - "EXPERT_MAP_RECORD", "false") == "true": +if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true": import vllm_ascend.patch.platform.patch_multiproc_executor # noqa -if envs.VLLM_ASCEND_BALANCE_SCHEDULING and vllm_version_is('0.14.0'): +if envs.VLLM_ASCEND_BALANCE_SCHEDULING and vllm_version_is("0.14.0"): import vllm_ascend.patch.platform.patch_balance_schedule # noqa diff --git a/vllm_ascend/patch/platform/patch_balance_schedule.py b/vllm_ascend/patch/platform/patch_balance_schedule.py index ae840e46..9a34eb37 100644 --- a/vllm_ascend/patch/platform/patch_balance_schedule.py +++ b/vllm_ascend/patch/platform/patch_balance_schedule.py @@ -7,17 +7,14 @@ import torch.distributed as dist import vllm from vllm.config import ParallelConfig from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.transformers_utils.config import \ - maybe_register_config_serialize_by_value +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc @@ -30,7 +27,6 @@ logger = init_logger(__name__) class BalanceScheduler(Scheduler): - def __init__( self, vllm_config, @@ -41,9 +37,15 @@ class BalanceScheduler(Scheduler): include_finished_set: bool = False, log_stats: bool = False, ) -> None: - super().__init__(vllm_config, kv_cache_config, - structured_output_manager, block_size, mm_registry, - include_finished_set, log_stats) + super().__init__( + vllm_config, + kv_cache_config, + structured_output_manager, + block_size, + mm_registry, + include_finished_set, + log_stats, + ) # Balance scheduling. self.balance_queue = [ torch.tensor([0], dtype=torch.int, device="cpu") @@ -51,9 +53,7 @@ class BalanceScheduler(Scheduler): ] def balance_gather(self, dp_group): - running_tensor = torch.tensor([len(self.running)], - dtype=torch.int, - device="cpu") + running_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu") dist.all_gather(self.balance_queue, running_tensor, group=dp_group) def schedule(self) -> SchedulerOutput: @@ -89,33 +89,32 @@ class BalanceScheduler(Scheduler): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if (request.num_output_placeholders > 0 - # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). - # Since output placeholders are also included in the computed tokens - # count, we subtract (num_output_placeholders - 1) to remove any draft - # tokens, so that we can be sure no further steps are needed even if - # they are all rejected. - and request.num_computed_tokens + 2 - - request.num_output_placeholders - >= request.num_prompt_tokens + request.max_tokens): + if ( + request.num_output_placeholders > 0 + # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). + # Since output placeholders are also included in the computed tokens + # count, we subtract (num_output_placeholders - 1) to remove any draft + # tokens, so that we can be sure no further steps are needed even if + # they are all rejected. + and request.num_computed_tokens + 2 - request.num_output_placeholders + >= request.num_prompt_tokens + request.max_tokens + ): # Async scheduling: Avoid scheduling an extra step when we are sure that # the previous step has reached request.max_tokens. We don't schedule # partial draft tokens since this prevents uniform decode optimizations. req_index += 1 continue - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) + num_new_tokens = ( + request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens + ) if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. - num_new_tokens = min( - num_new_tokens, - self.max_model_len - 1 - request.num_computed_tokens) + num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens) # Schedule encoder inputs. encoder_inputs_to_schedule = None @@ -174,20 +173,17 @@ class BalanceScheduler(Scheduler): self.running.remove(preempted_req) if preempted_req in scheduled_running_reqs: scheduled_running_reqs.remove(preempted_req) - token_budget += num_scheduled_tokens[ - preempted_req.request_id] + token_budget += num_scheduled_tokens[preempted_req.request_id] req_to_new_blocks.pop(preempted_req.request_id) num_scheduled_tokens.pop(preempted_req.request_id) - scheduled_spec_decode_tokens.pop( - preempted_req.request_id, None) - preempted_encoder_inputs = scheduled_encoder_inputs.pop( - preempted_req.request_id, None) + scheduled_spec_decode_tokens.pop(preempted_req.request_id, None) + preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None) if preempted_encoder_inputs: # Restore encoder compute budget if the preempted # request had encoder inputs scheduled in this step. num_embeds_to_restore = sum( - preempted_req.get_num_encoder_embeds(i) - for i in preempted_encoder_inputs) + preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs + ) encoder_compute_budget += num_embeds_to_restore req_index -= 1 else: @@ -212,23 +208,20 @@ class BalanceScheduler(Scheduler): # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens - - request.num_output_placeholders) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] - scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids # New spec tokens will be set in `update_draft_token_ids` before the # next step when applicable. request.spec_token_ids = [] # Encoder-related. if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -243,8 +236,10 @@ class BalanceScheduler(Scheduler): scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -257,9 +252,7 @@ class BalanceScheduler(Scheduler): if len(self.running) == self.max_num_running_reqs: break - balance_flag = (max( - t.item() - for t in self.balance_queue) == self.max_num_running_reqs) + balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs if balance_flag: break @@ -292,9 +285,14 @@ class BalanceScheduler(Scheduler): # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -306,14 +304,15 @@ class BalanceScheduler(Scheduler): # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = ( - self.kv_cache_manager.get_computed_blocks(request)) + new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks( + request + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: - ext_tokens, load_kv_async = ( - self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) if ext_tokens is None: # The request cannot be scheduled because @@ -327,8 +326,7 @@ class BalanceScheduler(Scheduler): num_external_computed_tokens = ext_tokens # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -356,8 +354,7 @@ class BalanceScheduler(Scheduler): # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if (not self.scheduler_config.enable_chunked_prefill - and num_new_tokens > token_budget): + if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget: # If chunked_prefill is disabled, # we can stop the scheduling here. break @@ -388,9 +385,7 @@ class BalanceScheduler(Scheduler): # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -398,8 +393,7 @@ class BalanceScheduler(Scheduler): # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens = ( - self.scheduler_config.max_num_encoder_input_tokens) + num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens else: num_encoder_tokens = 0 @@ -442,20 +436,17 @@ class BalanceScheduler(Scheduler): self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -465,8 +456,7 @@ class BalanceScheduler(Scheduler): request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -476,8 +466,7 @@ class BalanceScheduler(Scheduler): for i in external_load_encoder_input: self.encoder_cache_manager.allocate(request, i) if self.ec_connector is not None: - self.ec_connector.update_state_after_alloc( - request, i) + self.ec_connector.update_state_after_alloc(request, i) # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) @@ -491,20 +480,15 @@ class BalanceScheduler(Scheduler): # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( - scheduled_running_reqs) <= len(self.running) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) - with record_function_or_nullcontext( - "schedule: get_num_common_prefix_blocks"): + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): if self.running: any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id) # Construct the scheduler output. if self.use_v2_model_runner: @@ -515,17 +499,16 @@ class BalanceScheduler(Scheduler): req, req_to_new_blocks[req.request_id].get_block_ids(), req._all_token_ids, - ) for req in scheduled_new_reqs + ) + for req in scheduled_new_reqs ] else: new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] - with record_function_or_nullcontext( - "schedule: make_cached_request_data"): + with record_function_or_nullcontext("schedule: make_cached_request_data"): cached_reqs_data = self._make_cached_request_data( scheduled_running_reqs, scheduled_resumed_reqs, @@ -546,15 +529,13 @@ class BalanceScheduler(Scheduler): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, - preempted_req_ids={req.request_id - for req in preempted_reqs}, + preempted_req_ids={req.request_id for req in preempted_reqs}, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), ) # NOTE(Kuntai): this function is designed for multiple purposes: @@ -562,14 +543,12 @@ class BalanceScheduler(Scheduler): # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta: KVConnectorMetadata = self.connector.build_connector_meta( - scheduler_output) + meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta # Build the connector meta for ECConnector if self.ec_connector is not None: - ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( - scheduler_output) + ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output) scheduler_output.ec_connector_metadata = ec_meta with record_function_or_nullcontext("schedule: update_after_schedule"): @@ -578,7 +557,6 @@ class BalanceScheduler(Scheduler): class BalanceDPEngineCoreProc(DPEngineCoreProc): - def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -602,23 +580,23 @@ class BalanceDPEngineCoreProc(DPEngineCoreProc): self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. - self.engines_running = self._has_global_unfinished_reqs( - local_unfinished_reqs) + self.engines_running = self._has_global_unfinished_reqs(local_unfinished_reqs) self.scheduler.balance_gather(self.dp_group) if not self.engines_running: if self.dp_rank == 0 or not self.has_coordinator: # Notify client that we are pausing the loop. - logger.debug("Wave %d finished, pausing engine loop.", - self.current_wave) + logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) # In the coordinator case, dp rank 0 sends updates to the # coordinator. Otherwise (offline spmd case), each rank # sends the update to its colocated front-end process. client_index = -1 if self.has_coordinator else 0 - self.output_queue.put_nowait(( - client_index, - EngineCoreOutputs(wave_complete=self.current_wave), - )) + self.output_queue.put_nowait( + ( + client_index, + EngineCoreOutputs(wave_complete=self.current_wave), + ) + ) # Increment wave count and reset step counter. self.current_wave += 1 self.step_counter = 0 diff --git a/vllm_ascend/patch/platform/patch_ec_connector.py b/vllm_ascend/patch/platform/patch_ec_connector.py index 61ca8535..f7666b74 100644 --- a/vllm_ascend/patch/platform/patch_ec_connector.py +++ b/vllm_ascend/patch/platform/patch_ec_connector.py @@ -1,21 +1,21 @@ import vllm.distributed.ec_transfer.ec_connector.example_connector from safetensors.torch import load_file -from vllm.distributed.ec_transfer.ec_connector.example_connector import ( - ECConnectorMetadata, ECExampleConnector) +from vllm.distributed.ec_transfer.ec_connector.example_connector import ECConnectorMetadata, ECExampleConnector from vllm.logger import logger class AscendECExampleConnector(ECExampleConnector): - def start_load_caches(self, encoder_cache, **kwargs) -> None: metadata: ECConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, ECConnectorMetadata) assert encoder_cache is not None if metadata is None: - logger.warning(( - "In connector.start_load_caches, ", - "but the connector metadata is None", - )) + logger.warning( + ( + "In connector.start_load_caches, ", + "but the connector metadata is None", + ) + ) return # Load the EC for each mm data for mm_data in metadata.mm_datas: @@ -24,8 +24,7 @@ class AscendECExampleConnector(ECExampleConnector): filename = self._generate_filename_debug(mm_data.mm_hash) ec_cache = load_file(filename)["ec_cache"].npu() encoder_cache[mm_data.mm_hash] = ec_cache - logger.debug("Success load encoder cache for hash %s", - mm_data.mm_hash) + logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash) vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py index 18939b0f..beb7cec7 100644 --- a/vllm_ascend/patch/platform/patch_mamba_config.py +++ b/vllm_ascend/patch/platform/patch_mamba_config.py @@ -38,7 +38,8 @@ def verify_and_update_config(cls, vllm_config) -> None: block_size=1, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), - dtype=kv_cache_dtype).page_size_bytes + dtype=kv_cache_dtype, + ).page_size_bytes model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, @@ -58,23 +59,20 @@ def verify_and_update_config(cls, vllm_config) -> None: # block size to multiple of 16, so let's suggest a value # that would work (note: FA is currently not compatible # with mamba layers, use FlashInfer instead). - attn_block_size = block_alignment_bytes * cdiv( - mamba_page_size, block_alignment_bytes * attn_page_size_1_token) + attn_block_size = block_alignment_bytes * cdiv(mamba_page_size, block_alignment_bytes * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. - if (cache_config.block_size is None - or cache_config.block_size < attn_block_size): + if cache_config.block_size is None or cache_config.block_size < attn_block_size: cache_config.block_size = attn_block_size logger.info( - "Setting attention block size to %d tokens " - "to ensure that attention page size is >= mamba page size.", - attn_block_size) + "Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.", + attn_block_size, + ) # compute new attention page size - attn_page_size = \ - cache_config.block_size * attn_page_size_1_token + attn_page_size = cache_config.block_size * attn_page_size_1_token assert attn_page_size >= mamba_page_size @@ -83,15 +81,15 @@ def verify_and_update_config(cls, vllm_config) -> None: return # pad mamba page size to exactly match attention - if (cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size): - cache_config.mamba_page_size_padded = (attn_page_size) - mamba_padding_pct = 100 * (attn_page_size - - mamba_page_size) / mamba_page_size + if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size: + cache_config.mamba_page_size_padded = attn_page_size + mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size logger.info( "Padding mamba page size by %.2f%% to ensure " "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) + "exactly equal.", + mamba_padding_pct, + ) vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config diff --git a/vllm_ascend/patch/platform/patch_multiproc_executor.py b/vllm_ascend/patch/platform/patch_multiproc_executor.py index 11c2a433..540ad238 100644 --- a/vllm_ascend/patch/platform/patch_multiproc_executor.py +++ b/vllm_ascend/patch/platform/patch_multiproc_executor.py @@ -7,19 +7,20 @@ from multiprocessing.synchronize import Lock as LockType import vllm.v1.executor.multiproc_executor from vllm import envs from vllm.config import VllmConfig -from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -from vllm.utils.network_utils import (get_distributed_init_method, - get_loopback_ip, get_open_port) +from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue +from vllm.utils.network_utils import get_distributed_init_method, get_loopback_ip, get_open_port from vllm.utils.system_utils import get_mp_context from vllm.v1.executor.abstract import FailureCallback from vllm.v1.executor.multiproc_executor import ( - FutureWrapper, MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc, - set_multiprocessing_worker_envs) + FutureWrapper, + MultiprocExecutor, + UnreadyWorkerProcHandle, + WorkerProc, + set_multiprocessing_worker_envs, +) class AscendMultiprocExecutor(MultiprocExecutor): - def _init_executor(self) -> None: # Call self.shutdown at exit to clean up # and ensure workers will be terminated. @@ -32,7 +33,8 @@ class AscendMultiprocExecutor(MultiprocExecutor): assert self.world_size % self.parallel_config.nnodes_within_dp == 0, ( f"global world_size ({self.parallel_config.world_size}) must be " f"divisible by nnodes_within_dp " - f"({self.parallel_config.nnodes_within_dp}). ") + f"({self.parallel_config.nnodes_within_dp}). " + ) self.local_world_size = self.parallel_config.local_world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size @@ -41,7 +43,8 @@ class AscendMultiprocExecutor(MultiprocExecutor): f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"_parallel_size ({pp_parallel_size}) x prefill_context" - f"_parallel_size ({pcp_parallel_size}). ") + f"_parallel_size ({pcp_parallel_size}). " + ) # Set multiprocessing envs set_multiprocessing_worker_envs() @@ -49,8 +52,7 @@ class AscendMultiprocExecutor(MultiprocExecutor): # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # get_loopback_ip() for communication. - distributed_init_method = get_distributed_init_method( - get_loopback_ip(), get_open_port()) + distributed_init_method = get_distributed_init_method(get_loopback_ip(), get_open_port()) self.rpc_broadcast_mq: MessageQueue | None = None scheduler_output_handle: Handle | None = None # Initialize worker and set up message queues for SchedulerOutputs @@ -72,8 +74,7 @@ class AscendMultiprocExecutor(MultiprocExecutor): unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: - global_start_rank = (self.local_world_size * - self.parallel_config.node_rank_within_dp) + global_start_rank = self.local_world_size * self.parallel_config.node_rank_within_dp for local_rank in range(self.local_world_size): global_rank = global_start_rank + local_rank unready_workers.append( @@ -84,7 +85,8 @@ class AscendMultiprocExecutor(MultiprocExecutor): distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, shared_worker_lock=shared_worker_lock, - )) + ) + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -101,13 +103,11 @@ class AscendMultiprocExecutor(MultiprocExecutor): if self.parallel_config.node_rank_within_dp == 0: for rank in range(self.world_size): if rank < self.local_world_size: - local_message_queue = self.workers[ - rank].worker_response_mq + local_message_queue = self.workers[rank].worker_response_mq assert local_message_queue is not None self.response_mqs.append(local_message_queue) else: - remote_message_queue = self.workers[ - 0].peer_worker_response_mqs[rank] + remote_message_queue = self.workers[0].peer_worker_response_mqs[rank] assert remote_message_queue is not None self.response_mqs.append(remote_message_queue) @@ -128,8 +128,7 @@ class AscendMultiprocExecutor(MultiprocExecutor): for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() - self._ensure_worker_termination( - [uw.proc for uw in unready_workers]) + self._ensure_worker_termination([uw.proc for uw in unready_workers]) self.futures_queue = deque[tuple[FutureWrapper, Callable]]() @@ -137,7 +136,6 @@ class AscendMultiprocExecutor(MultiprocExecutor): class AscendWorkerProc(WorkerProc): - @staticmethod def make_worker_process( vllm_config: VllmConfig, diff --git a/vllm_ascend/patch/platform/patch_sched_yield.py b/vllm_ascend/patch/platform/patch_sched_yield.py index 694b9577..6aafd030 100644 --- a/vllm_ascend/patch/platform/patch_sched_yield.py +++ b/vllm_ascend/patch/platform/patch_sched_yield.py @@ -3,11 +3,10 @@ import sys import vllm.distributed.utils from vllm.platforms import CpuArchEnum, Platform -is_arm = (Platform.get_cpu_architecture() == CpuArchEnum.ARM) +is_arm = Platform.get_cpu_architecture() == CpuArchEnum.ARM USE_SCHED_YIELD = ( - ((sys.version_info[:3] >= (3, 11, 1)) or - (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)) - and not is_arm) + (sys.version_info[:3] >= (3, 11, 1)) or (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8) +) and not is_arm vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD