[fix] ops gatingtopk: fix nightly CI error (#4340)

### What this PR does / why we need it?

PR https://github.com/vllm-project/vllm-ascend/pull/2958 generalized the gatingtopk operator but broke the nightly CI. This PR adds a `check_npu_moe_gating_top_k` helper that validates the gating inputs (logits shape, expert-group counts, top-k bounds, scoring function) before the fused `npu_moe_gating_top_k` path is taken, which fixes the nightly CI.

- vLLM version: v0.12.0

Signed-off-by: 1092626063 <1092626063@qq.com>
```diff
@@ -28,7 +28,8 @@ import torch
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (
+    check_npu_moe_gating_top_k, select_experts)
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.token_dispatcher import \
     TokenDispatcherWithAllGather
```
```diff
@@ -303,7 +304,10 @@ def test_select_experts(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    if use_grouped_topk:
+    call_moe_gatingtopk = check_npu_moe_gating_top_k(
+        hidden_states, topk, topk_group, num_expert_group, scoring_func,
+        custom_routing_function)
+    if not call_moe_gatingtopk and use_grouped_topk:
         mock_native_grouped_topk.assert_called_once()
     else:
         mock_native_grouped_topk.assert_not_called()
```
```diff
@@ -823,8 +823,7 @@ class TestSelectExperts(TestBase):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=False,
-                scoring_func="invalid_func",
-                custom_routing_function=self.mock_custom_routing)
+                scoring_func="invalid_func")
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
```
```diff
@@ -834,15 +833,13 @@ class TestSelectExperts(TestBase):
                                              self.top_k,
                                              dtype=torch.long))
 
-        weights, ids = select_experts(
-            hidden_states=self.hidden_states,
-            router_logits=self.router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=True,
-            renormalize=False,
-            topk_group=4,
-            num_expert_group=2,
-            custom_routing_function=self.mock_custom_routing)
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=True,
+                                      renormalize=False,
+                                      topk_group=4,
+                                      num_expert_group=2)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
```
```diff
@@ -864,8 +861,7 @@ class TestSelectExperts(TestBase):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias,
-            custom_routing_function=self.mock_custom_routing)
+            e_score_correction_bias=e_score_correction_bias)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
```
```diff
@@ -60,7 +60,15 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    if custom_routing_function is None:
+    is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
+        hidden_states=hidden_states,
+        top_k=top_k,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        scoring_func=scoring_func,
+        custom_routing_function=custom_routing_function)
+
+    if is_support_npu_moe_gating_top_k:
         topk_weights, topk_ids = _select_experts_with_fusion_ops(
             hidden_states=hidden_states,
             router_logits=router_logits,
```
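The old guard only skipped the fused path when a custom routing function was supplied; the new helper (defined in the next hunk) also rejects group configurations the fused kernel cannot handle, so they fall back to the native implementation instead of erroring. A hypothetical illustration, assuming the patched vllm-ascend is importable (the shapes are made up, not taken from the failing CI job):

```python
import torch

# check_npu_moe_gating_top_k is the helper added by this PR.
from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k

# 250 is not divisible by num_expert_group=8, so the helper returns False
# and select_experts takes the native grouped top-k path instead.
x = torch.randn(4, 250)
assert not check_npu_moe_gating_top_k(hidden_states=x,
                                      top_k=8,
                                      topk_group=4,
                                      num_expert_group=8)
```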
```diff
@@ -90,6 +98,32 @@ def select_experts(hidden_states: torch.Tensor,
     return topk_weights, topk_ids
 
 
+def check_npu_moe_gating_top_k(
+        hidden_states: torch.Tensor,
+        top_k: int,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        scoring_func: str = "softmax",
+        custom_routing_function: Optional[Callable] = None):
+    if custom_routing_function is not None:
+        return False
+    if scoring_func != "softmax" and scoring_func != "sigmoid":
+        return False
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
+            == 0 and hidden_states.shape[-1] // num_expert_group > 2):
+        return False
+    if topk_group < 1 or topk_group > num_expert_group:
+        return False
+    if top_k < 1 or \
+            top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
+        return False
+    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
+        return False
+    return True
+
+
 def _native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
```
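For intuition, here is how the helper behaves for a well-formed configuration versus one with custom routing. The values below are illustrative only, not taken from the CI failure:

```python
import torch

from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k

x = torch.randn(4, 256)  # 4 tokens; the last dim divides evenly into 8 groups

# Eligible: 8 groups of 32, pick the top-4 groups, then the top-8 experts;
# 8 <= 256 / (8 * 4), so every bound in the helper is satisfied.
assert check_npu_moe_gating_top_k(hidden_states=x,
                                  top_k=8,
                                  topk_group=4,
                                  num_expert_group=8,
                                  scoring_func="sigmoid")

# Any custom routing function immediately forces the native path.
assert not check_npu_moe_gating_top_k(
    hidden_states=x, top_k=8, custom_routing_function=lambda *args: None)
```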
```diff
@@ -172,12 +206,9 @@ def _select_experts_with_fusion_ops(
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    if scoring_func == "softmax":
-        norm_type = 0
-        topk_group = 1
-        num_expert_group = 1
-    else:
-        norm_type = 1
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    norm_type = 0 if scoring_func == "softmax" else 1
     if e_score_correction_bias is not None and \
             e_score_correction_bias.dtype != router_logits.dtype:
         e_score_correction_bias = e_score_correction_bias.to(
```
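Instead of forcing `topk_group` and `num_expert_group` to 1 for softmax scoring, the caller's values are now preserved and only `norm_type` encodes the activation (0 for softmax, 1 for sigmoid). A rough pure-PyTorch sketch of that mapping, as a reading aid for the diff rather than the fused kernel's exact numerics:

```python
import torch

def gating_scores(router_logits: torch.Tensor,
                  scoring_func: str) -> torch.Tensor:
    # Mirrors the scoring_func -> norm_type mapping in the hunk above:
    # "softmax" -> norm_type 0, otherwise ("sigmoid") -> norm_type 1.
    norm_type = 0 if scoring_func == "softmax" else 1
    if norm_type == 0:
        return torch.softmax(router_logits, dim=-1)
    return torch.sigmoid(router_logits)
```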