【fix】ops gatingtopk fix nightly ci error (#4340)
### What this PR does / why we need it? PR https://github.com/vllm-project/vllm-ascend/pull/2958 added support for gatingtopk operator generalization, but it caused a nightly CI error. This PR adds a logits check for the gatingtopk op and fixes the nightly CI. - vLLM version: v0.12.0 Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
@@ -823,8 +823,7 @@ class TestSelectExperts(TestBase):
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid_func",
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
scoring_func="invalid_func")
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_grouped_topk(self, mock_topk):
|
||||
@@ -834,15 +833,13 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
|
||||
mock_topk.assert_called()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
@@ -864,8 +861,7 @@ class TestSelectExperts(TestBase):
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
mock_grouped_topk.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
Reference in New Issue
Block a user