refactor select_experts of moe module (#2150)
### What this PR does / why we need it?
this pr refactor select_experts of moe module
i merge implementations of quantitative and non-quantitative method in a
new class
use such as vllm like ExpertsSelector.select_experts
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
test in qwen3-moe and all ut.
- vLLM version: v0.10.0
- vLLM main:
e18859298d
Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
@@ -5,12 +5,13 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
fused_experts, fused_experts_310p,
|
||||
native_grouped_topk,
|
||||
quant_per_tensor, select_experts)
|
||||
quant_per_tensor)
|
||||
|
||||
|
||||
class TestQuantPerTensor(TestBase):
|
||||
@@ -772,7 +773,7 @@ class TestSelectExperts(TestBase):
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
|
||||
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
|
||||
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
|
||||
"""Test grouped topk with expert score correction bias"""
|
||||
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
|
||||
@@ -868,9 +869,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
|
||||
with patch('torch.topk',
|
||||
return_value=(None, expected_topk_indices)) as mock_topk:
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=2)
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=2)
|
||||
|
||||
mock_topk.assert_called_once()
|
||||
|
||||
@@ -885,9 +886,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
expected_topk_indices = torch.tensor([[0], [1]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=1)
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=1)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
|
||||
@@ -900,7 +901,7 @@ class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
expected_topk_indices = torch.tensor([[0], [0]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=1,
|
||||
topk_group=1)
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=1,
|
||||
topk_group=1)
|
||||
self.assertTrue(result.numel() > 0)
|
||||
|
||||
Reference in New Issue
Block a user