refactor select_experts of moe module (#2150)

### What this PR does / why we need it?
This PR refactors `select_experts` in the MoE module.
It merges the quantized and non-quantized implementations into a
single new class, invoked vLLM-style as `ExpertsSelector.select_experts`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tested with Qwen3-MoE and all unit tests.

- vLLM version: v0.10.0
- vLLM main:
e18859298d

Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
shiyuan680
2025-08-14 11:50:53 +08:00
committed by GitHub
parent 103654ccd6
commit e14f2ef669
10 changed files with 359 additions and 370 deletions

View File

@@ -5,12 +5,13 @@ import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
select_experts)
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod,
fused_experts, fused_experts_310p,
native_grouped_topk,
quant_per_tensor, select_experts)
quant_per_tensor)
class TestQuantPerTensor(TestBase):
@@ -772,7 +773,7 @@ class TestSelectExperts(TestBase):
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
"""Test grouped topk with expert score correction bias"""
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
@@ -868,9 +869,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
with patch('torch.topk',
return_value=(None, expected_topk_indices)) as mock_topk:
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
mock_topk.assert_called_once()
@@ -885,9 +886,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
expected_topk_indices = torch.tensor([[0], [1]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
expected_result = torch.tensor(
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
@@ -900,7 +901,7 @@ class TestNativeGroupedTopkPartialMock(TestBase):
expected_topk_indices = torch.tensor([[0], [0]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
self.assertTrue(result.numel() > 0)