[Qwen-moe] Remove the minor operation arange (#2373)

### What this PR does / why we need it?
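Integrate the `arange` operation that builds `row_idx` into expert selection:
`select_experts` now returns `row_idx` alongside `topk_weights` and `topk_ids`,
and `fused_experts` accepts it as an argument. This reduces the time spent on
the separate minor operation and improves performance.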

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main: 56dcf4e7e9

---------

Signed-off-by: s30076806 <songjiayang2@h-partners.com>
This commit is contained in:
s30076806
2025-08-27 09:13:31 +08:00
committed by GitHub
parent 358ba68994
commit 6a4ec186e7
9 changed files with 80 additions and 79 deletions

View File

@@ -92,8 +92,15 @@ def test_fused_experts(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
e_map)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
@@ -148,7 +155,7 @@ def test_select_experts(
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
@@ -169,6 +176,7 @@ def test_select_experts(
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)
@pytest.mark.parametrize("device", DEVICE)