From e14f2ef6690a04a3359220c94d6ea3eed7af5049 Mon Sep 17 00:00:00 2001 From: shiyuan680 <72335504+shiyuan680@users.noreply.github.com> Date: Thu, 14 Aug 2025 11:50:53 +0800 Subject: [PATCH] refactor select_experts of moe module (#2150) ### What this PR does / why we need it? This PR refactors `select_experts` in the MoE module: the quantized and non-quantized implementations are merged into a single entry point in a new module, invoked in a vLLM-like style (`ExpertsSelector.select_experts`). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tested with qwen3-moe and all unit tests. - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/e18859298d109870b22cb5b8672d1078818e268d Signed-off-by: yangcheng Co-authored-by: yangcheng (AJ) --- tests/e2e/singlecard/ops/test_fused_moe.py | 5 +- tests/ut/ops/test_fused_ops.py | 26 ++ tests/ut/quantization/test_w8a8.py | 25 +- vllm_ascend/ops/common_fused_moe.py | 9 +- vllm_ascend/ops/fused_moe.py | 181 ++------------ vllm_ascend/ops/layers/__init__.py | 0 vllm_ascend/ops/layers/experts_selector.py | 269 +++++++++++++++++++++ vllm_ascend/quantization/w4a8_dynamic.py | 44 ++-- vllm_ascend/quantization/w8a8.py | 124 +--------- vllm_ascend/quantization/w8a8_dynamic.py | 46 ++-- 10 files changed, 359 insertions(+), 370 deletions(-) create mode 100644 vllm_ascend/ops/layers/__init__.py create mode 100644 vllm_ascend/ops/layers/experts_selector.py diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index d04f3a6..21e0a4d 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -26,7 +26,8 @@ import pytest import torch from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.fused_moe import fused_experts, select_experts +from vllm_ascend.ops.fused_moe import fused_experts +from vllm_ascend.ops.layers.experts_selector import select_experts NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -142,7 +143,7 @@ def test_select_experts( dtype=torch.int32) custom_routing_function.return_value = (mock_weights, mock_ids) -    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk" +    with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk" ) as mock_native_grouped_topk: mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( x) diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 8c16ec4..ec3037f 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm_ascend.ascend_forward_context import _get_fused_moe_state from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) +from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 adapt_patch(True) @@ -389,3 +390,28 @@ class TestAscendUnquantizedFusedMoEMethod: assert result.shape == (16, 2) else: assert result.shape == x.shape + + +class TestExpertsSelector: + + @pytest.mark.parametrize("global_num_experts", [[256], [128]]) + def test_select_experts(self, mock_dist_env, mock_moe_env, + global_num_experts): + + x = torch.randn(8, 2) + router_logits = torch.randn(8, 2) + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=2, + use_grouped_topk=False, + renormalize=True, + topk_group=None, + num_expert_group=None, + custom_routing_function=None, +
scoring_func="softmax", + e_score_correction_bias=None, + global_num_experts=global_num_experts) + + assert topk_weights.shape == (8, 2) + assert topk_ids.shape == (8, 2) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 392355a..63b017c 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -5,12 +5,13 @@ import torch from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk, + select_experts) from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod, fused_experts, fused_experts_310p, - native_grouped_topk, - quant_per_tensor, select_experts) + quant_per_tensor) class TestQuantPerTensor(TestBase): @@ -772,7 +773,7 @@ class TestSelectExperts(TestBase): self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('vllm_ascend.quantization.w8a8.native_grouped_topk') + @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk') def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): """Test grouped topk with expert score correction bias""" mock_grouped_topk.return_value = torch.ones(self.num_tokens, @@ -868,9 +869,9 @@ class TestNativeGroupedTopkPartialMock(TestBase): with patch('torch.topk', return_value=(None, expected_topk_indices)) as mock_topk: - result = native_grouped_topk(topk_weights=topk_weights, - num_expert_group=2, - topk_group=2) + result = _native_grouped_topk(topk_weights=topk_weights, + num_expert_group=2, + topk_group=2) mock_topk.assert_called_once() @@ -885,9 +886,9 @@ class TestNativeGroupedTopkPartialMock(TestBase): expected_topk_indices = torch.tensor([[0], [1]]) with patch('torch.topk', return_value=(None, expected_topk_indices)): - result = native_grouped_topk(topk_weights=topk_weights, - num_expert_group=2, - topk_group=1) + result = _native_grouped_topk(topk_weights=topk_weights, + num_expert_group=2, + topk_group=1) expected_result = torch.tensor( [[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0], @@ -900,7 +901,7 @@ class TestNativeGroupedTopkPartialMock(TestBase): expected_topk_indices = torch.tensor([[0], [0]]) with patch('torch.topk', return_value=(None, expected_topk_indices)): - result = native_grouped_topk(topk_weights=topk_weights, - num_expert_group=1, - topk_group=1) + result = _native_grouped_topk(topk_weights=topk_weights, + num_expert_group=1, + topk_group=1) self.assertTrue(result.numel() > 0) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index b97aef7..19a86a7 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.layer import \ UnquantizedFusedMoEMethod from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts, - unified_fused_experts) +from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts +from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.utils import is_310p original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ @@ -59,7 +59,7 @@ def forward_oot( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: Optional[int] = None, + 
global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -69,7 +69,6 @@ def forward_oot( logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: topk_weights, topk_ids = select_experts( - global_num_experts=global_num_experts, hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -80,7 +79,7 @@ def forward_oot( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - ) + global_num_experts=global_num_experts) if topk_ids.shape[1] < top_k or is_310p(): assert global_num_experts is not None diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index d1d3a3a..c772124 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -46,6 +46,7 @@ from vllm_ascend.distributed.communication_op import \ from vllm_ascend.distributed.moe_comm_method import MoECommMethod from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.ops.sequence_parallel import MetadataForPadding @@ -920,143 +921,6 @@ def fused_experts( return final_hidden_states -def native_grouped_topk( - topk_weights: torch.Tensor, - num_expert_group: Optional[int], - topk_group: Optional[int], -): - topk_group = 0 if topk_group is None else topk_group - num_expert_group = 0 if num_expert_group is None else num_expert_group - - num_token = topk_weights.shape[0] - grouped_weights = topk_weights.view(num_token, num_expert_group, - -1).max(dim=-1).values - topk_group_indices = torch.topk(grouped_weights.to(torch.float32), - k=topk_group, - dim=-1, - sorted=False)[1] - topk_group_mask = torch.zeros_like(grouped_weights) - topk_group_mask.scatter_(1, topk_group_indices, 1) - topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) - topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) - - return topk_weights - - -def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: Optional[torch.Tensor] = None -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Select top-k experts based on router logits. - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - router_logits: Router logits of shape (num_tokens, num_experts). - top_k: Number of experts to select. - use_grouped_topk: Whether to group experts before selecting top-k. - renormalize: Whether to renormalize the routing weights. - topk_group: Number of expert groups to select from. - num_expert_group: Number of experts in each group. - custom_routing_function: Custom routing function. - scoring_func: Scoring function to use. - e_score_correction_bias: Correction bias to apply to expert scores. - - Returns: - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). 
- - Raises: - ValueError: If an unsupported scoring function is provided. - """ - - def _renormalize_topk_weights( - topk_weights: torch.Tensor, - renormalize: bool, - ): - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, - keepdim=True) - return topk_weights - - if scoring_func == "softmax": - # NOTE: vLLM use dtype=torch.float here - if not use_grouped_topk and custom_routing_function is None: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( - x=router_logits, finished=None, k=top_k) - topk_ids = topk_ids.to(torch.int32) - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - return topk_weights, topk_ids - - topk_weights = router_logits.softmax(dim=-1) - elif scoring_func == "sigmoid": - topk_weights = router_logits.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use biased - # scores for expert selection but original scores for routing weights - original_weights = topk_weights - topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) - - # TODO: Change to npu_group_topk when the latest CANN and NNAL is available - # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) - topk_weights = native_grouped_topk(topk_weights, num_expert_group, - topk_group) - # TODO bfloat16 is not supported in torch.topk with ge graph. - if e_score_correction_bias is not None: - topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_weights.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False) - topk_ids = topk_ids.to(torch.int32) - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - return topk_weights, topk_ids - - if custom_routing_function is not None: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - global_num_experts=global_num_experts) - # Required by npu_moe_init_routing - topk_ids = topk_ids.to(torch.int32) - return topk_weights, topk_ids - - topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) - topk_weights = topk_weights.to(hidden_states.dtype) - - # Required by npu_moe_init_routing - topk_ids = topk_ids.to(torch.int32) - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - - return topk_weights, topk_ids - - class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): def __init__(self, moe: FusedMoEConfig = None): @@ -1111,36 +975,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): **kwargs, ) -> torch.Tensor: - is_deepseek_v3_r1 = global_num_experts == 256 - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the 
third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + is_unquantized=True) topk_weights = topk_weights.to(x.dtype) # this is a naive implementation for experts load balance so as diff --git a/vllm_ascend/ops/layers/__init__.py b/vllm_ascend/ops/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/layers/experts_selector.py new file mode 100644 index 0000000..c906cf3 --- /dev/null +++ b/vllm_ascend/ops/layers/experts_selector.py @@ -0,0 +1,269 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Callable, Optional + +import torch +import torch_npu + + +def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + is_unquantized: bool = False, + global_num_experts: int = -1): + """ + Fused experts with select experts. + + Args: + router_logits: router logits of shape (num_tokens, hidden_size). + hidden_states: Hidden states of shape (num_tokens, hidden_size). + top_k: number of top k experts. + use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + indices_type: dtype of indices + is_unquantized: Whether the data are unquantized. + global_num_experts: Global number of experts. + + Returns: + topk_weights: router weights of shape (num_tokens, top_k). + topk_ids: selected expert IDs of shape (num_tokens, top_k). 
+ """ + + topk_weights, topk_ids = _select_experts_with_fusion_ops( + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + global_num_experts=global_num_experts, + is_unquantized=is_unquantized) + + if topk_weights is None: + topk_weights, topk_ids = _native_select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + return topk_weights, topk_ids + + +def _native_grouped_topk( + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], +): + topk_group = 0 if topk_group is None else topk_group + num_expert_group = 0 if num_expert_group is None else num_expert_group + + num_token = topk_weights.shape[0] + grouped_weights = topk_weights.view(num_token, num_expert_group, + -1).max(dim=-1).values + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] + topk_group_mask = torch.zeros_like(grouped_weights) + topk_group_mask.scatter_(1, topk_group_indices, 1) + topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) + + return topk_weights + + +def _renormalize_topk_weights( + topk_weights: torch.Tensor, + renormalize: bool, +): + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights + + +def _select_expert_use_group_topk( + topk_weights: torch.Tensor, topk_group: Optional[int], + renormalize: bool, top_k: int, num_expert_group: Optional[int], + e_score_correction_bias: Optional[torch.Tensor]): + assert topk_group is not None + assert num_expert_group is not None + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_weights = topk_weights + topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) + + # TODO: Change to npu_group_topk when the latest CANN and NNAL is available + # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) + topk_weights = _native_grouped_topk(topk_weights, num_expert_group, + topk_group) + # TODO bfloat16 is not supported in torch.topk with ge graph. 
+ if e_score_correction_bias is not None: + topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_weights.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False) + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + return topk_weights, topk_ids + + +def _select_experts_with_fusion_ops( + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + e_score_correction_bias: Optional[torch.Tensor], + topk_group: Optional[int], + num_expert_group: Optional[int], + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + global_num_experts: int = -1, + is_unquantized: bool = False): + + topk_weights, topk_ids = None, None + # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern + is_deepseek_v3_r1 = global_num_experts == 256 + if is_deepseek_v3_r1: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + + if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( + x=router_logits, finished=None, k=top_k) + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + + return topk_weights, topk_ids + + +def _native_select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Select top-k experts based on router logits. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + router_logits: Router logits of shape (num_tokens, num_experts). + top_k: Number of experts to select. + use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + + Returns: + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + + Raises: + ValueError: If an unsupported scoring function is provided. 
+ """ + + if scoring_func == "softmax": + topk_weights = router_logits.softmax(dim=-1) + elif scoring_func == "sigmoid": + topk_weights = router_logits.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if use_grouped_topk: + return _select_expert_use_group_topk( + topk_weights=topk_weights, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias) + + if custom_routing_function is not None: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts) + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + return topk_weights, topk_ids + + topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) + topk_weights = topk_weights.to(hidden_states.dtype) + + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + + return topk_weights, topk_ids diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 0b62fe1..ce238f4 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import select_experts +from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all, fused_experts_with_mc2) from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor @@ -245,34 +245,18 @@ class AscendW4A8DynamicFusedMoEMethod: 1] == global_num_experts, "Number of global experts mismatch" # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) fused_moe_state = get_forward_context().fused_moe_state 
shared_gate_up, shared_dequant_scale = None, None @@ -314,7 +298,7 @@ class AscendW4A8DynamicFusedMoEMethod: mc2_mask=kwargs.get("mc2_mask", None)) else: # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into fused_moe module. + # according to tp_size before they are feed into layers module. # Therefore, all2all is needed no matter how dp/tp is set so as to # dispatch/combine tokens. return fused_experts_with_all2all( diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index d3bff93..e4cbdc8 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p @@ -251,8 +252,7 @@ class AscendW8A8FusedMoEMethod: custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts, - ) + global_num_experts=global_num_experts) if is_310p(): return fused_experts_310p(hidden_states=x, @@ -645,123 +645,3 @@ def fused_experts( "currently does not support tensor parallelism") return final_hidden_states - - -def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts=-1, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Select top-k experts based on router logits. - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - router_logits: Router logits of shape (num_tokens, num_experts). - top_k: Number of experts to select. - use_grouped_topk: Whether to group experts before selecting top-k. - renormalize: Whether to renormalize the routing weights. - topk_group: Number of expert groups to select from. - num_expert_group: Number of experts in each group. - custom_routing_function: Custom routing function. - scoring_func: Scoring function to use. - e_score_correction_bias: Correction bias to apply to expert scores. - - Returns: - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - - Raises: - ValueError: If an unsupported scoring function is provided. - """ - - if scoring_func == "softmax": - # NOTE: vLLM use dtype=torch.float here - topk_weights = router_logits.softmax(dim=-1) - elif scoring_func == "sigmoid": - topk_weights = router_logits.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. 
We use biased - # scores for expert selection but original scores for routing weights - original_weights = topk_weights - topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) - - # TODO: Change to npu_group_topk when the latest CANN and NNAL is available - # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) - topk_weights = native_grouped_topk(topk_weights, num_expert_group, - topk_group) - # TODO bfloat16 is not supported in torch.topk with ge graph. - if e_score_correction_bias is not None: - topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_weights.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False) - elif custom_routing_function is None: - topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) - topk_weights = topk_weights.to(hidden_states.dtype) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - global_num_experts=global_num_experts, - ) - # Required by npu_moe_init_routing - topk_ids = topk_ids.to(torch.int32) - return topk_weights, topk_ids - - # Required by npu_moe_init_routing - topk_ids = topk_ids.to(torch.int32) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights, topk_ids - - -def native_grouped_topk( - topk_weights: torch.Tensor, - num_expert_group: Optional[int], - topk_group: Optional[int], -): - topk_group = 0 if topk_group is None else topk_group - num_expert_group = 0 if num_expert_group is None else num_expert_group - - num_token = topk_weights.shape[0] - grouped_weights = topk_weights.view(num_token, num_expert_group, - -1).max(dim=-1).values - topk_group_indices = torch.topk(grouped_weights.to(torch.float32), - k=topk_group, - dim=-1, - sorted=False)[1] - topk_group_mask = torch.zeros_like(grouped_weights) - topk_group_mask.scatter_(1, topk_group_indices, 1) - topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) - topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) - - return topk_weights diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 38aad66..21615f3 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -27,7 +27,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import select_experts +from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, dispose_tensor, get_ascend_soc_version) @@ -903,36 +903,18 @@ class AscendW8A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - is_deepseek_v3_r1 = global_num_experts == 256 - - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - 
router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) fused_moe_state = get_forward_context().fused_moe_state shared_gate_up, shared_dequant_scale = None, None @@ -995,7 +977,7 @@ class AscendW8A8DynamicFusedMoEMethod: expert_map=expert_map) else: # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into fused_moe module. + # according to tp_size before they are feed into layers module. # Therefore, all2all is needed no matter how dp/tp is set so as to # dispatch/combine tokens. return fused_experts_with_all2all(
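
For reference, a minimal usage sketch of the unified selector introduced above. The tensor shapes and argument values are illustrative only (not taken from the patch), and running it end to end assumes an Ascend environment with `torch` and `torch_npu` installed, since `vllm_ascend.ops.layers.experts_selector` imports `torch_npu` and may dispatch to its fused gating ops.

```python
# Illustrative sketch only: sizes and values are made up for demonstration.
import torch

from vllm_ascend.ops.layers.experts_selector import select_experts

num_tokens, hidden_size, num_experts, top_k = 8, 16, 64, 2

hidden_states = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)

# Quantized and unquantized FusedMoE methods now share this single entry point.
# With softmax scoring, no grouped top-k, no custom routing function, and the
# default is_unquantized=False, this exercises the native softmax + top-k path.
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=top_k,
    use_grouped_topk=False,
    renormalize=True,
    scoring_func="softmax",
    global_num_experts=num_experts,
)

assert topk_weights.shape == (num_tokens, top_k)
assert topk_ids.shape == (num_tokens, top_k)
assert topk_ids.dtype == torch.int32  # required by npu_moe_init_routing
```

The unquantized FusedMoE method additionally passes `is_unquantized=True` to opt into the fused `npu_moe_gating_top_k_softmax` path, and `global_num_experts == 256` (the DeepSeek-V3/R1 pattern) is routed to `npu_moe_gating_top_k` regardless of quantization.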