From 6a4ec186e731b9516235f4fd30b5b98227513fe7 Mon Sep 17 00:00:00 2001 From: s30076806 Date: Wed, 27 Aug 2025 09:13:31 +0800 Subject: [PATCH] [Qwen-moe] Remove the minor operation arange (#2373) ### What this PR does / why we need it? Build the `row_idx` index tensor once in `select_experts` (reusing the one returned by `npu_moe_gating_top_k_softmax` when that fused op is used) and pass it to the fused-experts functions, instead of rebuilding it with `torch.arange` inside each of them. This removes the redundant small `arange` operations and reduces per-call overhead. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/56dcf4e7e965e34043acf20ca4e4aceda21d41ec --------- Signed-off-by: s30076806 --- tests/e2e/singlecard/ops/test_fused_moe.py | 12 +++++- tests/ut/ops/test_fused_ops.py | 2 +- tests/ut/quantization/test_w8a8.py | 46 +++++++++++----------- tests/ut/quantization/test_w8a8_dynamic.py | 7 ++++ vllm_ascend/ops/common_fused_moe.py | 2 +- vllm_ascend/ops/fused_moe.py | 31 ++++----------- vllm_ascend/ops/layers/experts_selector.py | 27 ++++++++++--- vllm_ascend/quantization/w4a8_dynamic.py | 3 +- vllm_ascend/quantization/w8a8_dynamic.py | 29 ++++---------- 9 files changed, 80 insertions(+), 79 deletions(-) diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index 21e0a4d..ab673a4 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -92,8 +92,15 @@ def test_fused_experts( score = torch.softmax(score, dim=-1, dtype=dtype) topk_weights, topk_ids = torch.topk(score, topk) topk_ids = topk_ids.to(torch.int32) + row_idx = (torch.arange( + 0, + m * topk, + device=device, + dtype=torch.int32, + ).view(topk, -1).permute(1, 0).contiguous()) - output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map) + output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk, + e_map) torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map) # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem torch.testing.assert_close(output, torch_output, atol=4e-2, 
rtol=1) @@ -148,7 +155,7 @@ def test_select_experts( mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( x) - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=hidden_states, router_logits=router_logits, top_k=topk, @@ -169,6 +176,7 @@ def test_select_experts( assert topk_weights.shape == (m, topk) assert topk_ids.shape == (m, topk) assert topk_ids.dtype == torch.int32 + assert row_idx.shape == (m, topk) @pytest.mark.parametrize("device", DEVICE) diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 42370eb..6db32e6 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -405,7 +405,7 @@ class TestExpertsSelector: x = torch.randn(8, 2) router_logits = torch.randn(8, 2) - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=x, router_logits=router_logits, top_k=2, diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 63b017c..669f2b9 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -719,12 +719,12 @@ class TestSelectExperts(TestBase): def test_softmax_scoring(self): """Test softmax scoring function""" - weights, ids = select_experts(hidden_states=self.hidden_states, - router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=False, - scoring_func="softmax") + weights, ids, _ = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="softmax") self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) @@ -732,12 +732,12 @@ class TestSelectExperts(TestBase): def test_sigmoid_scoring(self): """Test sigmoid scoring function""" - weights, ids = select_experts(hidden_states=self.hidden_states, - 
router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=False, - scoring_func="sigmoid") + weights, ids, _ = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="sigmoid") self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) @@ -760,13 +760,13 @@ class TestSelectExperts(TestBase): self.top_k, dtype=torch.long)) - weights, ids = select_experts(hidden_states=self.hidden_states, - router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=False, - topk_group=4, - num_expert_group=2) + weights, ids, _ = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2) mock_topk.assert_called() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) @@ -780,7 +780,7 @@ class TestSelectExperts(TestBase): self.num_experts) e_score_correction_bias = torch.randn(self.num_experts) - weights, ids = select_experts( + weights, ids, _ = select_experts( hidden_states=self.hidden_states, router_logits=self.router_logits, top_k=self.top_k, @@ -803,7 +803,7 @@ class TestSelectExperts(TestBase): self.top_k, dtype=torch.int32)) - weights, ids = select_experts( + weights, ids, _ = select_experts( hidden_states=self.hidden_states, router_logits=self.router_logits, top_k=self.top_k, @@ -824,7 +824,7 @@ class TestSelectExperts(TestBase): self.top_k, dtype=torch.long)) - weights, _ = select_experts( + weights, ids, _ = select_experts( hidden_states=self.hidden_states, router_logits=self.router_logits, top_k=self.top_k, @@ -844,7 +844,7 @@ class TestSelectExperts(TestBase): self.top_k, dtype=torch.long)) - weights, ids = select_experts( + weights, ids, _ = select_experts( hidden_states=self.hidden_states, 
router_logits=self.router_logits, top_k=self.top_k, diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py index 59ab604..0e07eb1 100644 --- a/tests/ut/quantization/test_w8a8_dynamic.py +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -55,6 +55,12 @@ class TestAscendW8A8FusedMoEMethod(TestBase): torch.randn(self.num_tokens), ) mock_moe_finalize_routing.return_value = self.placeholder + row_idx_len = self.num_tokens * 8 + row_idx = (torch.arange( + 0, + row_idx_len, + dtype=torch.int32, + ).view(8, -1).permute(1, 0).contiguous()) result = fused_experts_with_all2all( hidden_states=self.placeholder, @@ -64,6 +70,7 @@ class TestAscendW8A8FusedMoEMethod(TestBase): w2_scale=self.placeholder, topk_weights=self.placeholder, topk_ids=self.placeholder, + row_idx=row_idx, top_k=8, expert_map=expert_map, ep_group=ep_group, diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index cc0f735..ffc1dea 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -130,7 +130,7 @@ def forward_oot( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 70e87dc..f02d146 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -326,6 +326,7 @@ def fused_experts_with_all2all( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + row_idx: torch.Tensor, top_k: int, expert_map: torch.Tensor = None, ep_group: GroupCoordinator = None, @@ -336,17 +337,10 @@ def fused_experts_with_all2all( num_tokens, _ = hidden_states.shape num_experts = w1.shape[0] - device = hidden_states.device if expert_map is not None: global_num_experts = len(expert_map) 
local_num_experts = global_num_experts // ep_group.world_size - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -380,12 +374,6 @@ def fused_experts_with_all2all( hidden_states = hidden_states[sorted_idx] else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -459,6 +447,7 @@ def fused_experts_with_all2all_buffer( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + row_idx: torch.Tensor, top_k: int, max_model_len: int, global_batch_size: int, @@ -470,14 +459,10 @@ def fused_experts_with_all2all_buffer( hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) num_tokens, _ = hidden_states.shape - device = hidden_states.device global_num_experts = len(expert_map) local_num_experts = global_num_experts // ep_group.world_size row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, - device=device).view(top_k, - -1).permute(1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -690,6 +675,7 @@ def fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + row_idx: torch.Tensor, top_k: int, expert_map: torch.Tensor = None, apply_router_weight_on_input: bool = False, @@ -781,12 +767,6 @@ def fused_experts( # Rearrange hidden_states sorted_hidden_states = hidden_states[sorted_token_indices] else: - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - 
device=device).view(top_k, -1).permute( - 1, 0).contiguous()) active_num = max_num_tokens if max_num_tokens is not None else num_tokens sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, @@ -908,7 +888,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): **kwargs, ) -> torch.Tensor: - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -952,6 +932,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, top_k=top_k, expert_map=expert_map) elif MOE_ALL2ALL_BUFFER: @@ -961,6 +942,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, top_k=top_k, max_model_len=self.max_model_len, global_batch_size=self.global_batch_size, @@ -982,6 +964,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, top_k=top_k, expert_map=expert_map, ep_group=get_ep_group()) diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/layers/experts_selector.py index c906cf3..11524ac 100644 --- a/vllm_ascend/ops/layers/experts_selector.py +++ b/vllm_ascend/ops/layers/experts_selector.py @@ -20,6 +20,17 @@ import torch import torch_npu +def return_row_idx(hidden_states, top_k): + num_tokens = hidden_states.shape[0] + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=hidden_states.device).view( + top_k, -1).permute(1, 0).contiguous()) + return row_idx + + def select_experts(hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, @@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor, topk_ids: selected expert IDs of shape 
(num_tokens, top_k). """ - topk_weights, topk_ids = _select_experts_with_fusion_ops( + topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops( + hidden_states=hidden_states, router_logits=router_logits, top_k=top_k, use_grouped_topk=use_grouped_topk, @@ -83,7 +95,9 @@ def select_experts(hidden_states: torch.Tensor, e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts, ) - return topk_weights, topk_ids + if row_idx is None: + row_idx = return_row_idx(hidden_states, top_k) + return topk_weights, topk_ids, row_idx def _native_grouped_topk( @@ -156,6 +170,7 @@ def _select_expert_use_group_topk( def _select_experts_with_fusion_ops( + hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, use_grouped_topk: bool, @@ -168,7 +183,7 @@ def _select_experts_with_fusion_ops( global_num_experts: int = -1, is_unquantized: bool = False): - topk_weights, topk_ids = None, None + topk_weights, topk_ids, row_idx = None, None, None # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern is_deepseek_v3_r1 = global_num_experts == 256 if is_deepseek_v3_r1: @@ -186,14 +201,14 @@ def _select_experts_with_fusion_ops( # y2_flag=False, # old api; should the third output be output routed_scaling_factor=1, eps=float(1e-20)) - + row_idx = return_row_idx(hidden_states, top_k) if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( + topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax( x=router_logits, finished=None, k=top_k) topk_ids = topk_ids.to(torch.int32) topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - return topk_weights, topk_ids + return topk_weights, topk_ids, row_idx def _native_select_experts( diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index f7d838d..a724615 100644 --- 
a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -268,7 +268,7 @@ class AscendW4A8DynamicFusedMoEMethod: 1] == global_num_experts, "Number of global experts mismatch" # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -334,6 +334,7 @@ class AscendW4A8DynamicFusedMoEMethod: w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, top_k=top_k, expert_map=expert_map, ep_group=self.ep_group, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 21615f3..cba090b 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -365,14 +365,9 @@ def fused_experts_with_mc2( return hidden_states, shared_output -def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts): +def init_routing_quant(hidden_states, top_k, topk_ids, row_idx, + global_num_experts): num_tokens, _ = hidden_states.shape - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=hidden_states.device).view( - top_k, -1).permute(1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -398,6 +393,7 @@ def fused_experts_with_all2all( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + row_idx: torch.Tensor, top_k: int, expert_map: torch.Tensor = None, ep_group: GroupCoordinator = None, @@ -431,7 +427,7 @@ def fused_experts_with_all2all( ) else: quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant( - hidden_states, top_k, topk_ids, global_num_experts) + hidden_states, top_k, topk_ids, row_idx, global_num_experts) gather_sizes = 
global_expert_tokens.new_empty( global_expert_tokens.shape[0]) @@ -463,12 +459,6 @@ def fused_experts_with_all2all( expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -627,6 +617,7 @@ def fused_experts(hidden_states: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + row_idx: torch.Tensor, top_k: int, expert_map: torch.Tensor = None): original_shape = hidden_states.shape @@ -677,12 +668,6 @@ def fused_experts(hidden_states: torch.Tensor, hidden_states = hidden_states[sorted_token_indices] group_list_type = 1 else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -903,7 +888,7 @@ class AscendW8A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -973,6 +958,7 @@ class AscendW8A8DynamicFusedMoEMethod: w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, top_k=top_k, expert_map=expert_map) else: @@ -988,6 +974,7 @@ class AscendW8A8DynamicFusedMoEMethod: w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, top_k=top_k, expert_map=expert_map, ep_group=self.ep_group,