[Perf]Enable npu_moe_gating_top_k_softmax on quantized scenarios (#2633)

### What this PR does / why we need it?
This PR enables `npu_moe_gating_top_k_softmax` when running quantized
MoE (such as W8A8). This op in fact makes no distinction between
quantized and non-quantized scenarios. Introducing this op reduces 3~4ms
for TPOT.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
ce30dca5c4

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2025-09-03 09:14:17 +08:00
committed by GitHub
parent 24d4dad7b2
commit b84465c525
3 changed files with 33 additions and 15 deletions

View File

@@ -716,8 +716,19 @@ class TestSelectExperts(TestBase):
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
def test_softmax_scoring(self):
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
@@ -816,13 +827,19 @@ class TestSelectExperts(TestBase):
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('torch.topk')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_renormalize(self, mock_topk):
"""Test weight renormalization"""
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long))
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
@@ -836,13 +853,19 @@ class TestSelectExperts(TestBase):
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
@patch('torch.topk')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long))
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,