refactor select_experts of moe module (#2150)

### What this PR does / why we need it?
this pr refactor select_experts of moe module
i merge implementations of quantitative and non-quantitative method in a
new class
use such as vllm like ExpertsSelector.select_experts
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
test in qwen3-moe and all ut.

- vLLM version: v0.10.0
- vLLM main:
e18859298d

Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
shiyuan680
2025-08-14 11:50:53 +08:00
committed by GitHub
parent 103654ccd6
commit e14f2ef669
10 changed files with 359 additions and 370 deletions

View File

@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
@@ -251,8 +252,7 @@ class AscendW8A8FusedMoEMethod:
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
global_num_experts=global_num_experts)
if is_310p():
return fused_experts_310p(hidden_states=x,
@@ -645,123 +645,3 @@ def fused_experts(
"currently does not support tensor parallelism")
return final_hidden_states
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
elif custom_routing_function is None:
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts,
)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights