diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py
index 30b1fe9db..bf2308a8f 100644
--- a/python/sglang/test/test_fp4_moe.py
+++ b/python/sglang/test/test_fp4_moe.py
@@ -1,6 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import Callable
+
 import pytest
 import torch
+from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 from sgl_kernel import scaled_fp4_quant
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -111,15 +114,16 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)
 
 
-@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
-@pytest.mark.parametrize("e", [40, 64, 256])
-@pytest.mark.parametrize("topk", [1, 6, 8])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
-@torch.inference_mode()
-def test_cutlass_fp4_moe_no_graph(
-    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+def check_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    moe_impl: Callable,
+    flip_w13: bool,
 ):
-    torch.manual_seed(7)
 
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -167,38 +171,18 @@ def test_cutlass_fp4_moe_no_graph(
     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
 
-    # strides for the cutlass moe_fp4 kernel
-    ab_strides_13 = torch.full(
-        (e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device
-    )
-    c_strides_13 = torch.full(
-        (e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device
-    )
-    ab_strides_2 = torch.full(
-        (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
-    )
-    c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
-    params = CutlassMoEParams(
-        CutlassMoEType.BlockscaledFP4,
-        device=a.device,
-        num_experts=e,
-        intermediate_size_per_partition=n,  # n
-        hidden_size=k,
-    )  # k
-    cutlass_output = cutlass_moe_fp4(
+    test_output = moe_impl(
         a=a,
-        a1_gscale=a1_gs,
-        w1_fp4=w1_q,
-        w1_blockscale=w1_blockscale,
-        w1_alphas=(1 / w1_gs),
-        a2_gscale=a2_gs,
-        w2_fp4=w2_q,
-        w2_blockscale=w2_blockscale,
-        w2_alphas=(1 / w2_gs),
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        params=params,
-        apply_router_weight_on_input=False,
+        w1_q=w1_q,
+        w2_q=w2_q,
+        a1_gs=a1_gs,
+        w1_blockscale=w1_blockscale,
+        w1_alphas=(1 / w1_gs),
+        a2_gs=a2_gs,
+        w2_blockscale=w2_blockscale,
+        w2_alphas=(1 / w2_gs),
     )
 
     # Reference check:
@@ -237,10 +221,108 @@ def test_cutlass_fp4_moe_no_graph(
             block_size=quant_blocksize,
         )
 
+    if flip_w13:
+        dim = -2
+        size = w1_d.size(dim)
+        assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+        half = size // 2
+        # Reorder weight
+        w1, w3 = w1_d.split(half, dim=dim)
+        w1_d = torch.cat([w3, w1], dim=dim).contiguous()
+
     torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
 
-    torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_cutlass_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def cutlass_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device=a.device,
+            num_experts=e,
+            intermediate_size_per_partition=n,  # n
+            hidden_size=k,
+        )  # k
+        return cutlass_moe_fp4(
+            a=a,
+            a1_gscale=a1_gs,
+            w1_fp4=w1_q,
+            w1_blockscale=w1_blockscale,
+            w1_alphas=w1_alphas,
+            a2_gscale=a2_gs,
+            w2_fp4=w2_q,
+            w2_blockscale=w2_blockscale,
+            w2_alphas=w2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=params,
+            apply_router_weight_on_input=False,
+        )
+
+    check_moe(m, n, k, e, topk, dtype, cutlass_moe_impl, flip_w13=False)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_flashinfer_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def flashinfer_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        return flashinfer_cutlass_fused_moe(
+            a,
+            topk_ids.to(torch.int),
+            topk_weights,
+            w1_q.view(torch.long),
+            w2_q.view(torch.long),
+            a.dtype,
+            quant_scales=[
+                a1_gs,
+                w1_blockscale.view(torch.int32),
+                w1_alphas,
+                a2_gs,
+                w2_blockscale.view(torch.int32),
+                w2_alphas,
+            ],
+        )[0]
+
+    check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
 
 
 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)