refactor select_experts of moe module (#2150)
### What this PR does / why we need it?
This PR refactors select_experts in the MoE module.
The implementations used by the quantized and non-quantized methods are merged into a
new class.
Usage mirrors vLLM, e.g. ExpertsSelector.select_experts.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tested with qwen3-moe and all unit tests.
- vLLM version: v0.10.0
- vLLM main:
e18859298d
Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||
|
||||
adapt_patch(True)
|
||||
@@ -389,3 +390,28 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
assert result.shape == (16, 2)
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
|
||||
|
||||
class TestExpertsSelector:
    """Shape check for ``select_experts`` with grouped top-k disabled."""

    # NOTE(review): each parametrized value is a one-element list ([256],
    # [128]) rather than a bare int -- confirm this is what select_experts
    # is meant to receive for ``global_num_experts``.
    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
    def test_select_experts(self, mock_dist_env, mock_moe_env,
                            global_num_experts):
        """Route random inputs and verify the shapes of both returned tensors."""
        num_tokens = 8
        top_k = 2

        # Draw the hidden states first, then the logits, so the RNG is
        # consumed in a fixed order.
        hidden = torch.randn(num_tokens, 2)
        logits = torch.randn(num_tokens, 2)

        routing_kwargs = {
            "hidden_states": hidden,
            "router_logits": logits,
            "top_k": top_k,
            "use_grouped_topk": False,
            "renormalize": True,
            "topk_group": None,
            "num_expert_group": None,
            "custom_routing_function": None,
            "scoring_func": "softmax",
            "e_score_correction_bias": None,
            "global_num_experts": global_num_experts,
        }
        weights, expert_ids = select_experts(**routing_kwargs)

        expected_shape = (num_tokens, top_k)
        assert weights.shape == expected_shape
        assert expert_ids.shape == expected_shape
|
||||
|
||||
Reference in New Issue
Block a user