diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index 9609b97d..e7577c2a 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -472,23 +472,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
         self.assertIsNotNone(result.dynamic_scale)
         self.assertEqual(result.group_list_type, 1)
 
-    def test_token_dispatch_with_log2phy(self):
-        hidden_states = torch.randn(8, 16)
-        topk_weights = torch.rand(8, 4)
-        topk_ids = torch.randint(0, 4, (8, 2)).long()
-        expert_map = torch.tensor([0, 1, 2, 3])
-        log2phy = torch.tensor([1, 0, 3, 2])
-
-        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
-            [0, 1], dtype=torch.int32)
-        self.dispatcher.local_expert_indices = [0, 1]
-
-        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
-                                                topk_weights=topk_weights,
-                                                topk_ids=topk_ids,
-                                                expert_map=expert_map,
-                                                log2phy=log2phy)
-
-        self.assertIsNotNone(result.hidden_states)
-        self.assertIsNotNone(result.group_list)
-        self.assertEqual(result.group_list_type, 1)
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index 94cedc1d..6b5cce8d 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -33,12 +33,7 @@ class VllmEplbAdaptor(EplbAdaptor):
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
-        if self.model.config.model_type == "qwen3_moe":
-            self.num_dense_layers = 0
-            self.global_expert_num = self.model.config.num_experts
-        else:
-            self.num_dense_layers = self.model.config.first_k_dense_replace
-            self.global_expert_num = self.model.config.n_routed_experts
+        self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
 
         for i in range(self.num_dense_layers,
@@ -64,17 +59,10 @@ class VllmEplbAdaptor(EplbAdaptor):
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
 
-        self.expert_map_per_layer = dict(
-        )  # reference to expert map on device for expert map update
         self.expert_map_per_layer_cpu = dict(
         )  # copy of expert map on CPU to avoid device synchronize frequently
 
-        for layer_idx in range(self.num_moe_layers):
-            self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
-                self.model.get_expert_map(self.num_dense_layers + layer_idx)
-        # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved
-        num_buffer_tensor = torch.where(
-            self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
+        num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
         self.buffer_tensor_list: list[list[Any]] = [
             [] for _ in range(num_buffer_tensor)
         ]
@@ -88,8 +76,6 @@ class VllmEplbAdaptor(EplbAdaptor):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
 
-        self.all_topk_ids = []
-
     def init_buffer_tensor(self, num_buffer_tensor):
         for buffer_id in range(num_buffer_tensor):
             for name in self.expert_weight_names:
@@ -169,7 +155,6 @@ class VllmEplbAdaptor(EplbAdaptor):
             json.dump(record, f, indent=4)
 
     def do_update_expert_map(self, layer_id, updated_expert_map):
-        self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
         self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
 
     def do_update_expert_weight(self, layer_id, local_expert_to_replace,
diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index 15468e9f..b76a4bbd 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -38,7 +38,6 @@ class EplbUpdator:
     def set_adaptor(self, adaptor):
         self.adaptor = adaptor
         self.num_moe_layers = self.adaptor.num_moe_layers
-        self.global_expert_num = self.adaptor.global_expert_num
 
     def init_eplb(self, expert_map_path, process):
         self.rank_id = dist.get_rank()
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 0aed4a61..2c169f5c 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -150,7 +150,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            global_num_experts=global_num_experts,
            expert_map=expert_map,
             apply_router_weight_on_input=apply_router_weight_on_input,
             dynamic_eplb=self.dynamic_eplb,
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index b5427fa7..829f8f5f 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -105,7 +105,6 @@ class MoECommMethod(ABC):
         use_int8_w8a8: bool = False,
         use_int4_w4a8: bool = False,
         use_int4_w4a16: bool = False,
-        global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[list[torch.Tensor]] = None,
         w2_scale: Optional[list[torch.Tensor]] = None,
@@ -128,12 +127,15 @@ class MoECommMethod(ABC):
         assert moe_comm_method is not None, "Missing communication context"
 
         before_dispatch_evt = torch.npu.current_stream().record_event()
+        # Apply log2phy if needed
+        if log2phy is not None:
+            topk_ids = log2phy[topk_ids]
+
         dispatch_results = self.token_dispatcher.token_dispatch(
             hidden_states=hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             expert_map=expert_map,
-            log2phy=log2phy,
             global_redundant_expert_num=self.moe_config.
             global_redundant_expert_num,
             mc2_mask=mc2_mask,
@@ -278,7 +280,6 @@ class FusedMC2CommImpl(MoECommMethod):
         use_int8_w8a8: bool = False,
         use_int4_w4a8: bool = False,
         use_int4_w4a16: bool = False,
-        global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[list[torch.Tensor]] = None,
         w2_scale: Optional[list[torch.Tensor]] = None,
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 4b699e42..b046e953 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -80,7 +80,6 @@ class MoETokenDispatcher(ABC):
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
@@ -188,7 +187,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
@@ -197,10 +195,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                        pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
 
-        # Apply log2phy if needed
-        if log2phy is not None:
-            topk_ids = log2phy[topk_ids]
-
         kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
                                                   topk_ids, expert_map,
                                                   mc2_mask,
@@ -309,7 +303,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
@@ -429,7 +422,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
@@ -439,9 +431,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
         self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
 
-        if log2phy is not None:
-            topk_ids = log2phy[topk_ids]
-
         (
             permutated_local_input_tokens,
             reversed_local_input_permutation_mapping,
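The vllm_adaptor.py change above replaces the model_type branch with a single getattr fallback: configs that define first_k_dense_replace (DeepSeek-style models) keep their value, while configs without the attribute (such as Qwen3-MoE) fall back to 0 dense layers. A minimal sketch of that fallback behavior, using stand-in SimpleNamespace objects rather than the real HuggingFace config classes the adaptor actually reads:

    from types import SimpleNamespace

    # Stand-ins for self.model.config; the real objects are HF config classes.
    deepseek_like = SimpleNamespace(first_k_dense_replace=3, num_hidden_layers=61)
    qwen3_moe_like = SimpleNamespace(num_hidden_layers=48)  # attribute absent

    for cfg in (deepseek_like, qwen3_moe_like):
        # Same fallback the patch introduces: missing attribute -> 0.
        num_dense_layers = getattr(cfg, "first_k_dense_replace", 0)
        num_moe_layers = cfg.num_hidden_layers - num_dense_layers
        print(num_dense_layers, num_moe_layers)  # -> 3 58, then 0 48

The per-model global_expert_num attribute (num_experts vs. n_routed_experts) is dropped along with the reference to it in eplb_updator.py's set_adaptor.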
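The fused-MoE side of the patch hoists the logical-to-physical expert-id remapping (log2phy) out of the individual token dispatchers (MC2, AllGather, All2AllV) and into the shared dispatch path in MoECommMethod, so token_dispatch no longer takes a log2phy argument and the remap runs exactly once before dispatch. The remap itself is a plain tensor gather; a self-contained sketch with illustrative values (the real tensors come from get_log2phy_map and the router):

    import torch

    # log2phy maps each logical expert id to a physical slot; with EPLB
    # redundant experts, several logical ids may resolve to duplicated copies.
    log2phy = torch.tensor([1, 0, 3, 2])

    # Per-token top-k routing decisions, expressed in logical expert ids.
    topk_ids = torch.tensor([[0, 2],
                             [3, 1]])

    # Advanced indexing remaps every entry while keeping the
    # (num_tokens, top_k) shape -- the `topk_ids = log2phy[topk_ids]` line
    # now applied once in MoECommMethod instead of inside each dispatcher.
    topk_ids = log2phy[topk_ids]
    print(topk_ids)  # tensor([[1, 3], [2, 0]])

With the argument gone from the dispatcher interface, the unit test test_token_dispatch_with_log2phy that exercised it through TokenDispatcherWithAll2AllV is removed as well.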