From f86596a66cc0aff2b05280303212758380f0ec9a Mon Sep 17 00:00:00 2001 From: sherie <963372609@qq.com> Date: Thu, 4 Sep 2025 11:56:29 +0800 Subject: [PATCH] allgather use fusedop. (#2689) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? Use 'npu_moe_init_routing_v2' & 'npu_moe_token_unpermute' to replace 'npu_moe_init_routing' & 'npu_moe_compute_expert_tokens' & 'npu_moe_finalize_routing' to optimize performance ### Does this PR introduce _any_ user-facing change? | branch| tps| TTFT |TPOT | | --- | --- | --- |--- | |main |733.98 | 280.05 |34.30 | |main+fusedop | 740.33 | 273.34 |33.99 | ### How was this patch tested? - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63 Signed-off-by: wangxiaoxin-sherie Co-authored-by: wangxiaoxin-sherie --- tests/e2e/singlecard/ops/test_fused_moe.py | 15 +- tests/ut/ops/test_token_dispatcher.py | 68 +++++---- .../ops/moe_dispatcher/token_dispatcher.py | 143 ++++-------------- 3 files changed, 66 insertions(+), 160 deletions(-) diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index cf13010..bc1309e 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -33,7 +33,7 @@ from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ TokenDispatcherWithAllGather NUM_EXPERTS = [8, 64] -EP_SIZE = [1, 4] +EP_SIZE = [1] TOP_KS = [2, 6] DEVICE = ["npu"] @@ -115,19 +115,6 @@ def test_token_dispatcher_with_all_gather( w1_local = w1 w2_local = w2 - if ep_size > 1: - local_e = e // ep_size - e_ids = torch.arange(local_e * 0, - local_e * (0 + 1), - device=device, - dtype=torch.int32) - expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32) - expert_map[e_ids] = torch.arange(local_e, - device=device, - dtype=torch.int32) - w1_local = w1[e_ids] - w2_local = w2[e_ids] - score = 
torch.softmax(score, dim=-1, dtype=dtype) topk_weights, topk_ids = torch.topk(score, topk) topk_ids = topk_ids.to(torch.int32) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 9de8a13..6782f45 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase): self.dispatcher = TokenDispatcherWithAllGather(**kwargs) # Mock NPU functions - self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing') - self.mock_moe_init_routing = self.patcher_moe_init_routing.start() - self.mock_moe_init_routing.return_value = ( + self.patcher_npu_moe_init_routing_v2 = patch( + 'torch_npu.npu_moe_init_routing_v2') + self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start( + ) + self.mock_npu_moe_init_routing_v2.return_value = ( torch.randn(6, 128), # sorted_hidden_states torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx - torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx - ) - - self.patcher_moe_compute_expert_tokens = patch( - 'torch_npu.npu_moe_compute_expert_tokens') - self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start( - ) - self.mock_moe_compute_expert_tokens.return_value = torch.tensor( - [3, 3]) # expert_tokens - - self.patcher_moe_finalize_routing = patch( - 'torch_npu.npu_moe_finalize_routing') - self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start( - ) - self.mock_moe_finalize_routing.return_value = torch.randn(3, 128) + torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx + torch.tensor([0, 1, 0, 1, 0, 1])) self.row_idx = torch.arange(10, dtype=torch.int32) + self.patcher_npu_moe_token_unpermute = patch( + 'torch_npu.npu_moe_token_unpermute') + self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start( + ) + self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128) def tearDown(self): - 
self.patcher_moe_init_routing.stop() - self.patcher_moe_compute_expert_tokens.stop() - self.patcher_moe_finalize_routing.stop() + self.patcher_npu_moe_init_routing_v2.stop() + self.patcher_npu_moe_token_unpermute.stop() def test_token_dispatch_without_expert_map(self): hidden_states = torch.randn(3, 128) @@ -207,10 +200,25 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_ids, self.row_idx, None) # Verify npu_moe_init_routing is called - self.mock_moe_init_routing.assert_called_once() - args, kwargs = self.mock_moe_init_routing.call_args + self.mock_npu_moe_init_routing_v2.assert_called_once() + args, kwargs = self.mock_npu_moe_init_routing_v2.call_args - self.assertEqual(results["group_list_type"], 0) + self.assertEqual(results["group_list_type"], 1) + + def test_token_dispatch_with_expert_map(self): + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + + results = self.dispatcher.token_dispatch(hidden_states, topk_weights, + topk_ids, self.row_idx, None) + + # Verify npu_moe_init_routing is called + self.mock_npu_moe_init_routing_v2.assert_called_once() + args, kwargs = self.mock_npu_moe_init_routing_v2.call_args + + self.assertEqual(results["group_list_type"], 1) def test_token_dispatch_with_quant(self): kwargs = { @@ -230,7 +238,7 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights, topk_ids, self.row_idx, None) - self.assertEqual(results["group_list_type"], 0) + self.assertEqual(results["group_list_type"], 1) def test_token_combine_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) @@ -242,9 +250,7 @@ class TestTokenDispatcherWithAllGather(TestBase): hidden_states = torch.randn(6, 128) final_hidden_states = self.dispatcher.token_combine(hidden_states) - - # Verify index_add_ is applied correctly - self.assertEqual(final_hidden_states.shape, (3, 128)) 
+ self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_combine_without_expert_map(self): self.dispatcher.with_quant = False @@ -260,10 +266,10 @@ class TestTokenDispatcherWithAllGather(TestBase): final_hidden_states = self.dispatcher.token_combine(hidden_states) # Verify npu_moe_finalize_routing is called - self.mock_moe_finalize_routing.assert_called_once() - args, kwargs = self.mock_moe_finalize_routing.call_args + self.mock_npu_moe_token_unpermute.assert_called_once() + args, kwargs = self.mock_npu_moe_token_unpermute.call_args - self.assertEqual(final_hidden_states.shape, (3, 128)) + self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_dispatch_with_router_weight(self): self.dispatcher.apply_router_weight_on_input = True diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 855faad..90b2209 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -338,8 +338,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.original_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() - dtype = hidden_states.dtype - device = hidden_states.device self.expert_map = expert_map self.topk_weights = topk_weights self.topk_ids = topk_ids @@ -353,67 +351,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): ), "Only support topk=1 when `apply_router_weight_on_input` is True" hidden_states = hidden_states * \ topk_weights.to(hidden_states.dtype) - if expert_map is not None: - # Generate token indices and flatten - token_indices = (torch.arange( - num_tokens, device=device, - dtype=torch.int64).unsqueeze(1).expand(-1, - self.top_k).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = expert_map[experts_flat] - - # Filter valid token-expert pairs - self.mask = 
local_experts_flat != -1 - filtered_weights = torch.where( - self.mask, weights_flat, - torch.zeros_like(weights_flat)).to(dtype) - filtered_experts = torch.where( - self.mask, local_experts_flat, - torch.full_like(local_experts_flat, - self.num_experts_local)).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - self.sorted_token_indices = token_indices[sort_indices] - self.sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(self.num_experts_local + 1, - device=device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), - ones) - token_counts = token_counts[:self.num_experts_local] - - # Rearrange hidden_states - sorted_hidden_states = hidden_states[self.sorted_token_indices] - if self.with_quant: - group_list_type = 1 - expert_tokens = token_counts - else: - expert_tokens = torch.cumsum(token_counts, - dim=0, - dtype=torch.int64) - group_list_type = 0 + global_num_experts = len(expert_map) + mask = (expert_map[topk_ids] != -1) + self.topk_weights = topk_weights * mask + first_expert_idx = get_ep_group( + ).rank_in_group * self.num_experts_local + last_expert_idx = first_expert_idx + self.num_experts_local else: - active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens - sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=active_num) + first_expert_idx = 0 + last_expert_idx = self.num_experts_local + global_num_experts = self.num_experts_local - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, self.num_experts_local) - 
expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 0 + sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = ( + torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + active_num=num_tokens * self.top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=-1, + )) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 1 # `count` mode return { "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, @@ -424,61 +386,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): hidden_states: torch.Tensor, bias: torch.Tensor = None): assert self.original_shape is not None - dtype = hidden_states.dtype - device = hidden_states.device - if self.expert_map is not None: - assert self.mask is not None - assert self.sorted_token_indices is not None - assert self.sorted_weights is not None - - weighted_down_out = hidden_states * \ - self.sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros(*self.original_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) - - # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] 
- # This created multiple NaN and index_add_ will mix them up which harms accuracy - # remove this mask and filter after it being fixed - num_valid_tokens = self.mask.sum() - valid_token_mask = torch.arange( - 0, self.sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens - valid_output = torch.where( - valid_token_mask, weighted_down_out, - torch.zeros_like(weighted_down_out)).to(dtype) - final_hidden_states.index_add_(0, self.sorted_token_indices, - valid_output) - else: - if self.with_quant: - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=self.topk_weights, - expanded_src_to_dst_row=self.expanded_row_idx, - export_for_source_row=self.topk_ids, - ) - if len(self.original_shape) == 3: - final_hidden_states = final_hidden_states.view( - self.original_shape) - else: - scales = torch.ones_like( - self.topk_weights - ) if self.apply_router_weight_on_input else self.topk_weights - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=scales, - expanded_src_to_dst_row=self.expanded_row_idx, - export_for_source_row=self.topk_ids, - ) + final_hidden_states = torch_npu.npu_moe_token_unpermute( + permuted_tokens=hidden_states, + sorted_indices=self.expanded_row_idx, + probs=self.topk_weights) + if len(self.original_shape) == 3: + final_hidden_states = final_hidden_states.view(self.original_shape) return final_hidden_states