【fix】ops gatingtopk fix nightly ci error (#4340)

### What this PR does / why we need it?
This pr https://github.com/vllm-project/vllm-ascend/pull/2958 is
supporting gatingtopk operator generalization, but caused nightly ci
error.
Now we add check logits for ops gatingtopk, and fix nightly ci.

- vLLM version: v0.12.0

Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
1092626063
2025-12-04 20:09:21 +08:00
committed by GitHub
parent da84eb2f40
commit b3e1377a92
3 changed files with 53 additions and 22 deletions

View File

@@ -823,8 +823,7 @@ class TestSelectExperts(TestBase):
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid_func",
custom_routing_function=self.mock_custom_routing)
scoring_func="invalid_func")
@patch('torch.topk')
def test_grouped_topk(self, mock_topk):
@@ -834,15 +833,13 @@ class TestSelectExperts(TestBase):
self.top_k,
dtype=torch.long))
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
custom_routing_function=self.mock_custom_routing)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -864,8 +861,7 @@ class TestSelectExperts(TestBase):
renormalize=False,
topk_group=4,
num_expert_group=2,
e_score_correction_bias=e_score_correction_bias,
custom_routing_function=self.mock_custom_routing)
e_score_correction_bias=e_score_correction_bias)
mock_grouped_topk.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))