refactor select_experts of moe module (#2150)

### What this PR does / why we need it?
this pr refactor select_experts of moe module
i merge implementations of quantitative and non-quantitative method in a
new class
use such as vllm like ExpertsSelector.select_experts
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
test in qwen3-moe and all ut.

- vLLM version: v0.10.0
- vLLM main:
e18859298d

Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
shiyuan680
2025-08-14 11:50:53 +08:00
committed by GitHub
parent 103654ccd6
commit e14f2ef669
10 changed files with 359 additions and 370 deletions

View File

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
unified_fused_experts)
from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -59,7 +59,7 @@ def forward_oot(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -69,7 +69,6 @@ def forward_oot(
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -80,7 +79,7 @@ def forward_oot(
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None