[optimize] fuse renormalize into moe_topk_softmax (#7744)

Co-authored-by: ispobock <ispobaoke@gmail.com>
This commit is contained in:
Yi Zhang
2025-07-04 03:42:44 +08:00
committed by GitHub
parent 6840a7bbb2
commit 2998c4bdf4
7 changed files with 254 additions and 101 deletions

View File

@@ -22,14 +22,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(
(num_tokens, topk), dtype=torch.int32, device="cuda"
)
topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
@@ -47,5 +43,97 @@ def test_topk_softmax(num_tokens, num_experts, topk):
), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
@pytest.mark.parametrize(
"num_tokens, num_experts, topk, dtype",
list(
itertools.product(
[1, 16, 128, 512, 1024, 2048], # num_tokens
[4, 8, 16, 32, 64, 128, 256], # num_experts
[1, 2, 4], # topk
[torch.float16, torch.bfloat16, torch.float32], # dtype
)
),
)
def test_topk_softmax_dtype_regression(num_tokens, num_experts, topk, dtype):
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
topk_softmax(
topk_weights,
topk_indices,
gating_output,
)
topk_weights_ref = torch.empty(
(num_tokens, topk), dtype=torch.float32, device="cuda"
)
topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
topk_softmax(
topk_weights_ref,
topk_indices_ref,
gating_output.float(),
)
assert torch.allclose(
topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
), f"Weights mismatch: SGLang old interface={topk_indices_ref} vs SGLang new interface={topk_weights}"
assert torch.allclose(
topk_indices_ref.int(), topk_indices, atol=0, rtol=0
), f"Indices mismatch: SGLang old interface={topk_indices_ref}, SGLang new interface={topk_indices}"
@pytest.mark.parametrize(
"num_tokens, num_experts, topk",
list(
itertools.product(
[1, 16, 128, 512, 1024, 2048], # num_tokens
[4, 8, 16, 32, 64, 128, 256], # num_experts
[1, 2, 4], # topk
)
),
)
def test_topk_softmax_renormalize(num_tokens, num_experts, topk):
gating_output = torch.randn(
(num_tokens, num_experts), dtype=torch.bfloat16, device="cuda"
)
topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
topk_softmax(
topk_weights,
topk_indices,
gating_output,
renormalize=True,
)
topk_weights_ref = torch.empty(
(num_tokens, topk), dtype=torch.float32, device="cuda"
)
topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
token_expert_indices_ref = torch.empty(
(num_tokens, topk), dtype=torch.int32, device="cuda"
)
topk_softmax(
topk_weights_ref,
topk_indices_ref,
gating_output,
)
topk_weights_ref = topk_weights_ref / topk_weights_ref.sum(dim=-1, keepdim=True)
assert torch.allclose(
topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
), f"Weights mismatch: SGLang w/o fused renormalize={topk_indices_ref} vs SGLang w/ fused renormalize={topk_weights}"
assert torch.allclose(
topk_indices_ref.int(), topk_indices, atol=0, rtol=0
), f"Indices mismatch: SGLang w/o fused renormalize={topk_indices_ref}, SGLang w/ fused renormalize={topk_indices}"
if __name__ == "__main__":
pytest.main([__file__])