refactor select_experts of moe module (#2150)
### What this PR does / why we need it?
this pr refactor select_experts of moe module
i merge implementations of quantitative and non-quantitative method in a
new class
use such as vllm like ExpertsSelector.select_experts
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
test in qwen3-moe and all ut.
- vLLM version: v0.10.0
- vLLM main:
e18859298d
Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
@@ -26,7 +26,8 @@ import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.fused_moe import fused_experts, select_experts
|
||||
from vllm_ascend.ops.fused_moe import fused_experts
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
@@ -142,7 +143,7 @@ def test_select_experts(
|
||||
dtype=torch.int32)
|
||||
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
|
||||
with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
|
||||
) as mock_native_grouped_topk:
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
|
||||
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||
|
||||
adapt_patch(True)
|
||||
@@ -389,3 +390,28 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
assert result.shape == (16, 2)
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
|
||||
|
||||
class TestExpertsSelector:
|
||||
|
||||
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
|
||||
def test_select_experts(self, mock_dist_env, mock_moe_env,
|
||||
global_num_experts):
|
||||
|
||||
x = torch.randn(8, 2)
|
||||
router_logits = torch.randn(8, 2)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
use_grouped_topk=False,
|
||||
renormalize=True,
|
||||
topk_group=None,
|
||||
num_expert_group=None,
|
||||
custom_routing_function=None,
|
||||
scoring_func="softmax",
|
||||
e_score_correction_bias=None,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
assert topk_weights.shape == (8, 2)
|
||||
assert topk_ids.shape == (8, 2)
|
||||
|
||||
@@ -5,12 +5,13 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
fused_experts, fused_experts_310p,
|
||||
native_grouped_topk,
|
||||
quant_per_tensor, select_experts)
|
||||
quant_per_tensor)
|
||||
|
||||
|
||||
class TestQuantPerTensor(TestBase):
|
||||
@@ -772,7 +773,7 @@ class TestSelectExperts(TestBase):
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
|
||||
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
|
||||
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
|
||||
"""Test grouped topk with expert score correction bias"""
|
||||
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
|
||||
@@ -868,9 +869,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
|
||||
with patch('torch.topk',
|
||||
return_value=(None, expected_topk_indices)) as mock_topk:
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=2)
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=2)
|
||||
|
||||
mock_topk.assert_called_once()
|
||||
|
||||
@@ -885,9 +886,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
expected_topk_indices = torch.tensor([[0], [1]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=1)
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=1)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
|
||||
@@ -900,7 +901,7 @@ class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
expected_topk_indices = torch.tensor([[0], [0]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=1,
|
||||
topk_group=1)
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=1,
|
||||
topk_group=1)
|
||||
self.assertTrue(result.numel() > 0)
|
||||
|
||||
Reference in New Issue
Block a user