refactor select_experts of moe module (#2150)
### What this PR does / why we need it?
This PR refactors `select_experts` in the MoE module. The quantized and non-quantized implementations are merged into a single new class, which is used in a vLLM-like way via `ExpertsSelector.select_experts` (see the usage sketch below).
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tested with Qwen3-MoE and all unit tests.
- vLLM version: v0.10.0
- vLLM main: e18859298d
Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
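For orientation, here is a minimal usage sketch of the new unified entry point, reconstructed only from the import and call sites in the diff below; the dummy shapes and argument values are illustrative and are not taken from this PR.

```python
# Minimal usage sketch (illustrative values; assumes a vllm-ascend / NPU environment).
import torch

# New import path introduced by this PR (previously vllm_ascend.ops.fused_moe).
from vllm_ascend.ops.layers.experts_selector import select_experts

hidden_states = torch.randn(16, 4096)   # [num_tokens, hidden_size], placeholder shape
router_logits = torch.randn(16, 128)    # [num_tokens, global_num_experts], placeholder shape

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=8,
    use_grouped_topk=False,
    renormalize=True,
    topk_group=None,
    num_expert_group=None,
    custom_routing_function=None,
    scoring_func="softmax",
    e_score_correction_bias=None,
    global_num_experts=128)
```

Per the description, both the quantized path (e.g. `AscendW8A8DynamicFusedMoEMethod` below) and the non-quantized path now go through this single selector instead of carrying their own copies of the routing logic.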
@@ -27,7 +27,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.fused_moe import select_experts
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
@@ -903,36 +903,18 @@ class AscendW8A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        is_deepseek_v3_r1 = global_num_experts == 256
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk is currently set to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix: 8
-                group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
-                renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid (fix)
-                # out_flag=False,  # todo new api; whether to emit the third output
-                # y2_flag=False,  # old api; whether to emit the third output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            global_num_experts=global_num_experts)
 
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
@@ -995,7 +977,7 @@ class AscendW8A8DynamicFusedMoEMethod:
                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are feed into fused_moe module.
+            # according to tp_size before they are feed into layers module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
             return fused_experts_with_all2all(
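Since the refactor folds the DeepSeek-V3/R1 special case removed in the second hunk into the shared selector, the new `experts_selector` module presumably dispatches between `torch_npu.npu_moe_gating_top_k` and a generic routing path internally. The sketch below is a rough reconstruction based only on the code removed above; the internal structure, the simplified generic path, and the function name are assumptions, not the actual vllm-ascend implementation.

```python
# Hypothetical sketch of the unified selector's internal dispatch, reconstructed
# from the branch removed in this diff; NOT the actual experts_selector code.
import torch
import torch_npu  # Ascend NPU extension; provides npu_moe_gating_top_k


def select_experts_sketch(hidden_states, router_logits, top_k, use_grouped_topk,
                          renormalize, topk_group=None, num_expert_group=None,
                          custom_routing_function=None, scoring_func="softmax",
                          e_score_correction_bias=None, global_num_experts=-1):
    # npu_moe_gating_top_k currently supports only the group_count=256 layout,
    # i.e. the DeepSeek-V3/R1 expert configuration (see the removed NOTE above).
    if global_num_experts == 256:
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=top_k,
            bias=e_score_correction_bias,
            k_group=topk_group,
            group_count=num_expert_group,
            group_select_mode=1,       # 1: topk2.sum within each group
            renorm=0,                  # softmax -> topk
            norm_type=1,               # sigmoid scoring
            routed_scaling_factor=1,
            eps=float(1e-20))
        return topk_weights, topk_ids

    # Generic path, heavily simplified here to plain softmax + top-k; the real
    # selector also has to cover grouped top-k, custom routing functions and the
    # scoring_func / correction-bias handling shared by the quantized methods.
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```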
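The comment in the last hunk notes that hidden_states is already split across TP ranks before it reaches the layer, so expert-parallel execution needs an all-to-all exchange to dispatch tokens to the ranks owning their experts and a second exchange to combine the results. The snippet below is only a conceptual illustration of that dispatch/compute/combine round trip using `torch.distributed.all_to_all_single`; it is not `fused_experts_with_all2all`, and it assumes an initialized process group and tokens already permuted by destination rank.

```python
# Conceptual dispatch/combine round trip; NOT the fused_experts_with_all2all kernel.
import torch
import torch.distributed as dist


def dispatch_compute_combine(tokens: torch.Tensor,
                             send_counts: list[int],
                             recv_counts: list[int]) -> torch.Tensor:
    """tokens are assumed pre-permuted so the first send_counts[0] rows go to
    rank 0, the next send_counts[1] rows to rank 1, and so on."""
    hidden = tokens.shape[-1]

    # Dispatch: each rank receives the tokens routed to its local experts.
    recv_buf = tokens.new_empty(sum(recv_counts), hidden)
    dist.all_to_all_single(recv_buf, tokens,
                           output_split_sizes=recv_counts,
                           input_split_sizes=send_counts)

    # Local expert computation would run here; identity as a stand-in.
    expert_out = recv_buf

    # Combine: send the expert outputs back to the ranks the tokens came from.
    out_buf = tokens.new_empty(sum(send_counts), hidden)
    dist.all_to_all_single(out_buf, expert_out,
                           output_split_sizes=send_counts,
                           input_split_sizes=recv_counts)
    return out_buf
```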