[EPLB][Bugfix] policy_swift_balancer bugfix and renaming (#5897)
### What this PR does / why we need it?
1. Rename dynamic_ep to default_eplb.
2. Rename dynamic_ep_v2 to swift_balancer
3. Discard func compose_expert_update_info_bipartite.
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -25,8 +25,8 @@ vllm_ascend
|
||||
│ ├── core
|
||||
│ │ ├── policy
|
||||
│ │ │ ├── policy_abstract.py
|
||||
│ │ │ ├── policy_dynamic_ep.py
|
||||
│ │ │ ├── policy_dynamic_ep_v2.py
|
||||
│ │ │ ├── policy_default_eplb.py
|
||||
│ │ │ ├── policy_swift_balancer.py
|
||||
│ │ │ ├── policy_factory.py
|
||||
│ │ │ ├── policy_flashlb.py
|
||||
│ │ ├── eplb_device_transfer_loader.py
|
||||
@@ -52,9 +52,9 @@ vllm_ascend
|
||||
*Load balancing algorithms with factory pattern instantiation*
|
||||
- `policy_abstract.py`
|
||||
Abstract class for load balancing strategy interfaces
|
||||
- `policy_dynamic_ep.py`
|
||||
- `policy_default_eplb.py`
|
||||
Default implementation of open-source EPLB paper algorithm
|
||||
- `policy_dynamic_ep_v2.py`
|
||||
- `policy_swift_balancer.py`
|
||||
Enhanced version optimizing expert swaps for low-bandwidth devices (e.g., A2)
|
||||
- `policy_flashlb.py`
|
||||
Threshold-based adjustment reducing operational costs through layer-wise fluctuation detection
|
||||
|
||||
@@ -3,16 +3,16 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb
|
||||
from vllm_ascend.eplb.core.policy.policy_default_eplb import DefaultEplb
|
||||
|
||||
|
||||
class TestDynamicEplb:
|
||||
class TestDefaultEplb:
|
||||
|
||||
def test_add_redundant_basic(self):
|
||||
current_expert_table = np.array([[[0, 1], [1, 0]]])
|
||||
expert_workload = np.array([[[2, 3], [4, 1]]])
|
||||
num_original_expert = 2
|
||||
result = DynamicEplb.add_redundant(current_expert_table,
|
||||
result = DefaultEplb.add_redundant(current_expert_table,
|
||||
expert_workload,
|
||||
num_original_expert)
|
||||
expected = np.array([[2 + 1, 3 + 4]])
|
||||
@@ -20,51 +20,51 @@ class TestDynamicEplb:
|
||||
|
||||
def test_get_redundant_num(self):
|
||||
counts = np.array([2, 1, 3])
|
||||
assert DynamicEplb.get_redundant_num(3, counts) == 3
|
||||
assert DefaultEplb.get_redundant_num(3, counts) == 3
|
||||
|
||||
def test_calculate_max_heat_per_layer(self):
|
||||
workload_table = np.array([[[1, 2], [3, 4]], [[2, 2], [1, 1]]])
|
||||
max_heat = DynamicEplb.calculate_max_heat_per_layer(workload_table, 2)
|
||||
max_heat = DefaultEplb.calculate_max_heat_per_layer(workload_table, 2)
|
||||
assert max_heat == [7, 4]
|
||||
|
||||
def test_constraint_expert_local_exchange(self):
|
||||
current = [[[0, 1], [2, 3]]]
|
||||
global_dep = [[[1, 0], [3, 2]]]
|
||||
new_dep = DynamicEplb.constraint_expert_local_exchange(
|
||||
new_dep = DefaultEplb.constraint_expert_local_exchange(
|
||||
current, global_dep)
|
||||
assert new_dep == [[[0, 1], [2, 3]]]
|
||||
|
||||
def test_compute_balanced_pack_redundancy_normal(self):
|
||||
origin_weights = [(0, 10), (1, 20)]
|
||||
result, boxes = DynamicEplb.compute_balanced_pack_redundancy(
|
||||
result, boxes = DefaultEplb.compute_balanced_pack_redundancy(
|
||||
origin_weights, 2, 1)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
|
||||
def test_compute_balanced_pack_redundancy_card0(self):
|
||||
origin_weights = [(0, 10)]
|
||||
with pytest.raises(RuntimeError):
|
||||
DynamicEplb.compute_balanced_pack_redundancy(origin_weights, 0, 0)
|
||||
DefaultEplb.compute_balanced_pack_redundancy(origin_weights, 0, 0)
|
||||
|
||||
def test_compute_balanced_pack_normal(self):
|
||||
origin_weights = np.array([(0, 10), (1, 20)], dtype=object)
|
||||
result, boxes = DynamicEplb.compute_balanced_pack(origin_weights, 2)
|
||||
result, boxes = DefaultEplb.compute_balanced_pack(origin_weights, 2)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
|
||||
def test_compute_balanced_pack_card0(self):
|
||||
origin_weights = np.array([(0, 10)], dtype=object)
|
||||
with pytest.raises(RuntimeError):
|
||||
DynamicEplb.compute_balanced_pack(origin_weights, 0)
|
||||
DefaultEplb.compute_balanced_pack(origin_weights, 0)
|
||||
|
||||
def test_original_compute_balanced_pack_redundancy(self):
|
||||
origin_weights = [(0, 5), (1, 10)]
|
||||
result, boxes = DynamicEplb.original_compute_balanced_pack_redundancy(
|
||||
result, boxes = DefaultEplb.original_compute_balanced_pack_redundancy(
|
||||
origin_weights, 2, 1)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
|
||||
def test_rebalance_experts_normal(self):
|
||||
expert_table = np.array([[[0, 1], [1, 0]]])
|
||||
workload = np.array([[[2, 3], [4, 1]]])
|
||||
policy = DynamicEplb(config=None)
|
||||
policy = DefaultEplb(config=None)
|
||||
change, priority, new_dep = policy.rebalance_experts(
|
||||
expert_table, workload)
|
||||
assert change in [0, 1]
|
||||
@@ -73,12 +73,12 @@ class TestDynamicEplb:
|
||||
assert np.array(new_dep).shape == expert_table.shape
|
||||
|
||||
def test_rebalance_experts_exceptions(self):
|
||||
policy = DynamicEplb(config=None)
|
||||
policy = DefaultEplb(config=None)
|
||||
|
||||
# case1: num_original_expert != expert_num
|
||||
expert_table = np.array([[[0, 1], [1, 0]]])
|
||||
workload = np.array([[[2, 3], [4, 1]]])
|
||||
with patch.object(DynamicEplb,
|
||||
with patch.object(DefaultEplb,
|
||||
'add_redundant',
|
||||
return_value=np.array([[1, 2, 3]])):
|
||||
with pytest.raises(ValueError):
|
||||
@@ -93,6 +93,6 @@ class TestDynamicEplb:
|
||||
# case3: num_npus < num_redundancy_expert
|
||||
expert_table_small = np.array([[[0, 0]]]) # 1 layer, 1 NPU, 2 experts
|
||||
workload_small = np.array([[[1, 1]]])
|
||||
with patch.object(DynamicEplb, 'get_redundant_num', return_value=2):
|
||||
with patch.object(DefaultEplb, 'get_redundant_num', return_value=2):
|
||||
with pytest.raises(ValueError):
|
||||
policy.rebalance_experts(expert_table_small, workload_small)
|
||||
@@ -1,8 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.eplb.core.policy.policy_abstract import DynamicConfig
|
||||
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb
|
||||
from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import DynamicEplbV2
|
||||
from vllm_ascend.eplb.core.policy.policy_default_eplb import DefaultEplb
|
||||
from vllm_ascend.eplb.core.policy.policy_swift_balancer import SwiftBalanceEplb
|
||||
from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory
|
||||
from vllm_ascend.eplb.core.policy.policy_random import RandomLoadBalance
|
||||
|
||||
@@ -14,8 +14,8 @@ def dummy_config():
|
||||
|
||||
@pytest.mark.parametrize("policy_type, expected_class", [
|
||||
(0, RandomLoadBalance),
|
||||
(1, DynamicEplb),
|
||||
(2, DynamicEplbV2),
|
||||
(1, DefaultEplb),
|
||||
(2, SwiftBalanceEplb),
|
||||
(999, RandomLoadBalance),
|
||||
])
|
||||
def test_generate_policy(policy_type, expected_class, dummy_config):
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Dict, Set
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import (DynamicConfig,
|
||||
DynamicEplbV2)
|
||||
from vllm_ascend.eplb.core.policy.policy_swift_balancer import (DynamicConfig,
|
||||
SwiftBalanceEplb)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -14,7 +14,7 @@ def config():
|
||||
|
||||
@pytest.fixture
|
||||
def policy(config):
|
||||
return DynamicEplbV2(config)
|
||||
return SwiftBalanceEplb(config)
|
||||
|
||||
|
||||
def test_safe_operations(policy):
|
||||
@@ -34,19 +34,19 @@ def test_safe_operations(policy):
|
||||
def test_add_redundant():
|
||||
workload = np.array([[[1, 2], [3, 4]]])
|
||||
placement = np.array([[[0, 1], [0, 1]]])
|
||||
result = DynamicEplbV2.add_redundant(placement, workload, 2)
|
||||
result = SwiftBalanceEplb.add_redundant(placement, workload, 2)
|
||||
assert result.shape == (1, 2)
|
||||
assert np.all(result[0] == [4, 6]) # 0:1+3, 1:2+4
|
||||
|
||||
|
||||
def test_get_redundant_num():
|
||||
counts = np.array([1, 2, 1])
|
||||
assert DynamicEplbV2.get_redundant_num(3, counts) == 1 # sum(counts-1)
|
||||
assert SwiftBalanceEplb.get_redundant_num(3, counts) == 1 # sum(counts-1)
|
||||
|
||||
|
||||
def test_calculate_max_heat_per_layer():
|
||||
workload = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
||||
result = DynamicEplbV2.calculate_max_heat_per_layer(workload, 2)
|
||||
result = SwiftBalanceEplb.calculate_max_heat_per_layer(workload, 2)
|
||||
assert result == [7, 15]
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_compute_redundant_assignments(policy):
|
||||
def test_prepare_expert_list():
|
||||
base_experts = [(0, 10), (1, 5)]
|
||||
redundant_assignments = [[2], []]
|
||||
result = DynamicEplbV2.prepare_expert_list(base_experts,
|
||||
result = SwiftBalanceEplb.prepare_expert_list(base_experts,
|
||||
redundant_assignments, 1)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
@@ -79,7 +79,7 @@ def test_non_redundant_expert_information():
|
||||
origin_deployment = np.array([[0, 1]])
|
||||
updated_weights = [(0, 10), (1, 5)]
|
||||
rendun_pos: Dict[int, Set[int]] = {0: set()}
|
||||
assignments, weights, loads, counts = DynamicEplbV2.non_redundant_expert_information(
|
||||
assignments, weights, loads, counts = SwiftBalanceEplb.non_redundant_expert_information(
|
||||
origin_deployment, updated_weights, rendun_pos)
|
||||
assert assignments[0] == [0, 1]
|
||||
assert loads[0] == 15
|
||||
@@ -73,12 +73,8 @@ class EplbWorker:
|
||||
new_expert_maps = self.local2global(new_placement)
|
||||
self.update_expert_map(new_expert_maps)
|
||||
|
||||
if self.policy_type == 2:
|
||||
update_info = self.compose_expert_update_info_bipartite(
|
||||
new_expert_maps, self.old_expert_maps)
|
||||
else:
|
||||
update_info = self.compose_expert_update_info_greedy(
|
||||
new_expert_maps, self.old_expert_maps)
|
||||
update_info = self.compose_expert_update_info_greedy(
|
||||
new_expert_maps, self.old_expert_maps)
|
||||
self.old_expert_maps = new_expert_maps
|
||||
logger.info("EPLB Process compute complete")
|
||||
|
||||
@@ -124,112 +120,6 @@ class EplbWorker:
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
break
|
||||
|
||||
def compose_expert_update_info_bipartite(self, updated_expert_maps_org,
|
||||
current_expert_maps_org):
|
||||
# transform numpy array to torch tensor
|
||||
updated_expert_maps = updated_expert_maps_org.clone()
|
||||
current_expert_maps = current_expert_maps_org.clone()
|
||||
updated_expert_maps = np.array(updated_expert_maps)
|
||||
current_expert_maps = np.array(current_expert_maps)
|
||||
|
||||
num_layers = current_expert_maps.shape[0]
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
||||
current_expert_maps_this_layer = current_expert_maps[layer_id]
|
||||
updated_expert_maps_this_layer_org = updated_expert_maps_org[
|
||||
layer_id]
|
||||
|
||||
from typing import Any
|
||||
|
||||
expert_send_info_this_layer: dict[Any, Any] = {}
|
||||
expert_recv_info_this_layer: dict[Any, Any] = {}
|
||||
|
||||
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
||||
if (np.equal(updated_expert_maps_this_layer,
|
||||
current_expert_maps_this_layer)).all():
|
||||
yield (expert_send_info_this_layer,
|
||||
expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer_org, layer_id)
|
||||
|
||||
# Parse expert_ids each rank needs to receive from other ranks
|
||||
dst_rank_indices, experts_to_recv = np.where(
|
||||
(current_expert_maps_this_layer == -1)
|
||||
& (updated_expert_maps_this_layer != -1))
|
||||
|
||||
# record src ranks for potential transfer
|
||||
src_ranks_set = dict()
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
if expert_id not in src_ranks_set:
|
||||
src_ranks_set[expert_id] = np.where(
|
||||
current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
||||
|
||||
# loop until all experts are scheduled
|
||||
while len(dst_rank_indices) > 0:
|
||||
# construct bipartite graph
|
||||
graph_expert_update: nx.Graph = nx.Graph()
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
dst_rank_id = dst_rank_indices[idx].item()
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
# add src ranks
|
||||
src_rank_ids = src_ranks_set[expert_id]
|
||||
graph_expert_update.add_nodes_from(src_rank_ids,
|
||||
bipartite=0)
|
||||
# add dest rank
|
||||
graph_expert_update.add_node(str(dst_rank_id), bipartite=1)
|
||||
# add edges
|
||||
for src_rank_id in src_rank_ids:
|
||||
graph_expert_update.add_edge(src_rank_id,
|
||||
str(dst_rank_id))
|
||||
|
||||
# graph may not be connected
|
||||
connected_components = list(
|
||||
nx.connected_components(graph_expert_update))
|
||||
all_matches = {}
|
||||
# matching in this loop
|
||||
for i, component in enumerate(connected_components):
|
||||
subgraph = graph_expert_update.subgraph(component)
|
||||
component_matching = nx.bipartite.maximum_matching(
|
||||
subgraph)
|
||||
all_matches.update(component_matching)
|
||||
|
||||
for src_rank, dst_rank in all_matches.items():
|
||||
dst_rank = int(dst_rank)
|
||||
assert src_rank != dst_rank
|
||||
if graph_expert_update.nodes[src_rank]['bipartite'] == 0:
|
||||
# currently not scheduled experts in rank dst_rank
|
||||
experts_v = experts_to_recv[np.where(
|
||||
dst_rank_indices == dst_rank)]
|
||||
# src: src_rank, dest: dst_rank, expert: expert_id
|
||||
expert_id = np.intersect1d(
|
||||
experts_v,
|
||||
np.where(current_expert_maps_this_layer[src_rank]
|
||||
!= -1))[0]
|
||||
|
||||
# record send/rcv pairs
|
||||
if src_rank not in expert_send_info_this_layer:
|
||||
expert_send_info_this_layer[src_rank] = []
|
||||
if dst_rank not in expert_recv_info_this_layer:
|
||||
expert_recv_info_this_layer[dst_rank] = []
|
||||
expert_send_info_this_layer[src_rank].append(
|
||||
(dst_rank, expert_id))
|
||||
expert_recv_info_this_layer[dst_rank].append(
|
||||
(src_rank, expert_id))
|
||||
|
||||
remove_index = np.where(
|
||||
np.logical_and(dst_rank_indices == dst_rank,
|
||||
experts_to_recv == expert_id))
|
||||
|
||||
# update
|
||||
dst_rank_indices = np.delete(dst_rank_indices,
|
||||
remove_index)
|
||||
experts_to_recv = np.delete(experts_to_recv,
|
||||
remove_index)
|
||||
|
||||
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer_org, layer_id)
|
||||
|
||||
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
|
||||
def compose_expert_update_info_greedy(self, updated_expert_maps,
|
||||
current_expert_maps):
|
||||
|
||||
@@ -24,7 +24,7 @@ class DynamicTable:
|
||||
placement_table = None
|
||||
|
||||
|
||||
class DynamicEplb(EplbPolicy):
|
||||
class DefaultEplb(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
from .policy_dynamic_ep import DynamicEplb
|
||||
from .policy_dynamic_ep_v2 import DynamicEplbV2
|
||||
from .policy_default_eplb import DefaultEplb
|
||||
from .policy_swift_balancer import SwiftBalanceEplb
|
||||
from .policy_flashlb import FlashLB, warm_up
|
||||
from .policy_random import RandomLoadBalance
|
||||
|
||||
@@ -20,9 +20,9 @@ class PolicyFactory:
|
||||
0:
|
||||
RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
|
||||
1:
|
||||
DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
|
||||
DefaultEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
|
||||
2:
|
||||
DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
|
||||
SwiftBalanceEplb, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
|
||||
3:
|
||||
FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ class DynamicTable:
|
||||
placement_table = None
|
||||
|
||||
|
||||
class DynamicEplbV2(EplbPolicy):
|
||||
class SwiftBalanceEplb(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
Reference in New Issue
Block a user