use npu_moe_gating_top_k_softmax (#1355)
### What this PR does / why we need it?
The optimization solution for non-deepseek select_experts is to replace
gating_topk_softmax with softmax+topk+to, which is optimized from 37us
to 14us on bf16/fp16 of qwen3-235b
- vLLM version: v0.9.2
- vLLM main:
1a4f35e2ea
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -49,6 +49,7 @@ from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
||||
|
||||
|
||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
@@ -821,6 +822,39 @@ def fused_experts(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def select_gating_top_k_softmax_experts(
|
||||
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
|
||||
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
only supports float16、bfloat16、float32
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
router_logits: Router logits of shape (num_tokens, num_experts).
|
||||
top_k: Number of experts to select.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
|
||||
Returns:
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
router_logits, None, k=top_k)
|
||||
|
||||
# # Required by npu_moe_init_routing
|
||||
# topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
# topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
@@ -1013,6 +1047,12 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
# y2_flag=False, # old api; 第三个输出是否输出
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
||||
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
||||
Reference in New Issue
Block a user