[EPLB][Bugfix] Dispatch Allgather use log2phy if enable eplb (#5933)
### What this PR does / why we need it?
1. Move the logic of expert mapping forward to prevent shotgun changes
2. Disable the update of expert map.
### How was this patch tested?
a2
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| GPQA_diamond | 53064e | accuracy | gen | 73.23 |
a3
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -472,23 +472,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertIsNotNone(result.dynamic_scale)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
|
||||
def test_token_dispatch_with_log2phy(self):
|
||||
hidden_states = torch.randn(8, 16)
|
||||
topk_weights = torch.rand(8, 4)
|
||||
topk_ids = torch.randint(0, 4, (8, 2)).long()
|
||||
expert_map = torch.tensor([0, 1, 2, 3])
|
||||
log2phy = torch.tensor([1, 0, 3, 2])
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy)
|
||||
|
||||
self.assertIsNotNone(result.hidden_states)
|
||||
self.assertIsNotNone(result.group_list)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
|
||||
@@ -33,12 +33,7 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
self.rank_id = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
self.param_dict = dict(self.model.named_parameters())
|
||||
if self.model.config.model_type == "qwen3_moe":
|
||||
self.num_dense_layers = 0
|
||||
self.global_expert_num = self.model.config.num_experts
|
||||
else:
|
||||
self.num_dense_layers = self.model.config.first_k_dense_replace
|
||||
self.global_expert_num = self.model.config.n_routed_experts
|
||||
self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
|
||||
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
|
||||
|
||||
for i in range(self.num_dense_layers,
|
||||
@@ -64,17 +59,10 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
else:
|
||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||
|
||||
self.expert_map_per_layer = dict(
|
||||
) # reference to expert map on device for expert map update
|
||||
self.expert_map_per_layer_cpu = dict(
|
||||
) # copy of expert map on CPU to avoid device synchronize frequently
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
|
||||
self.model.get_expert_map(self.num_dense_layers + layer_idx)
|
||||
|
||||
# TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved
|
||||
num_buffer_tensor = torch.where(
|
||||
self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
|
||||
num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
|
||||
self.buffer_tensor_list: list[list[Any]] = [
|
||||
[] for _ in range(num_buffer_tensor)
|
||||
]
|
||||
@@ -88,8 +76,6 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
|
||||
self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
|
||||
|
||||
self.all_topk_ids = []
|
||||
|
||||
def init_buffer_tensor(self, num_buffer_tensor):
|
||||
for buffer_id in range(num_buffer_tensor):
|
||||
for name in self.expert_weight_names:
|
||||
@@ -169,7 +155,6 @@ class VllmEplbAdaptor(EplbAdaptor):
|
||||
json.dump(record, f, indent=4)
|
||||
|
||||
def do_update_expert_map(self, layer_id, updated_expert_map):
|
||||
self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
|
||||
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
|
||||
|
||||
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
|
||||
|
||||
@@ -38,7 +38,6 @@ class EplbUpdator:
|
||||
def set_adaptor(self, adaptor):
|
||||
self.adaptor = adaptor
|
||||
self.num_moe_layers = self.adaptor.num_moe_layers
|
||||
self.global_expert_num = self.adaptor.global_expert_num
|
||||
|
||||
def init_eplb(self, expert_map_path, process):
|
||||
self.rank_id = dist.get_rank()
|
||||
|
||||
@@ -150,7 +150,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
|
||||
@@ -105,7 +105,6 @@ class MoECommMethod(ABC):
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
@@ -128,12 +127,15 @@ class MoECommMethod(ABC):
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
before_dispatch_evt = torch.npu.current_stream().record_event()
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
dispatch_results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=self.moe_config.
|
||||
global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
@@ -278,7 +280,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
|
||||
@@ -80,7 +80,6 @@ class MoETokenDispatcher(ABC):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -188,7 +187,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -197,10 +195,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
self.with_quant = with_quant
|
||||
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
mc2_mask,
|
||||
@@ -309,7 +303,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -429,7 +422,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -439,9 +431,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
(
|
||||
permutated_local_input_tokens,
|
||||
reversed_local_input_permutation_mapping,
|
||||
|
||||
Reference in New Issue
Block a user