[Qwen-moe] Remove the minor operation arange (#2373)
### What this PR does / why we need it?
Integrate the `arange` operator (used to build `row_idx`) to reduce the
time spent and improve performance.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 56dcf4e7e9
---------
Signed-off-by: s30076806 <songjiayang2@h-partners.com>
This commit is contained in:
@@ -92,8 +92,15 @@ def test_fused_experts(
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
row_idx = (torch.arange(
|
||||
0,
|
||||
m * topk,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
).view(topk, -1).permute(1, 0).contiguous())
|
||||
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
|
||||
e_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||
@@ -148,7 +155,7 @@ def test_select_experts(
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=topk,
|
||||
@@ -169,6 +176,7 @@ def test_select_experts(
|
||||
assert topk_weights.shape == (m, topk)
|
||||
assert topk_ids.shape == (m, topk)
|
||||
assert topk_ids.dtype == torch.int32
|
||||
assert row_idx.shape == (m, topk)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
|
||||
Reference in New Issue
Block a user