refactor select_experts of moe module (#2150)
### What this PR does / why we need it?
this pr refactor select_experts of moe module
i merge implementations of quantitative and non-quantitative method in a
new class
use such as vllm like ExpertsSelector.select_experts
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
test in qwen3-moe and all ut.
- vLLM version: v0.10.0
- vLLM main:
e18859298d
Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
|
||||
|
||||
@@ -251,8 +252,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if is_310p():
|
||||
return fused_experts_310p(hidden_states=x,
|
||||
@@ -645,123 +645,3 @@ def fused_experts(
|
||||
"currently does not support tensor parallelism")
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts=-1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
router_logits: Router logits of shape (num_tokens, num_experts).
|
||||
top_k: Number of experts to select.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
|
||||
Returns:
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
|
||||
if scoring_func == "softmax":
|
||||
# NOTE: vLLM use dtype=torch.float here
|
||||
topk_weights = router_logits.softmax(dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights = router_logits.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_weights = topk_weights
|
||||
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
):
|
||||
topk_group = 0 if topk_group is None else topk_group
|
||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||
|
||||
num_token = topk_weights.shape[0]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
||||
k=topk_group,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
topk_group_mask = torch.zeros_like(grouped_weights)
|
||||
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
||||
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
||||
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
||||
|
||||
return topk_weights
|
||||
|
||||
Reference in New Issue
Block a user