diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 8e4143e0e..216424eea 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -147,8 +147,8 @@ def cutlass_w4a8_moe( k, ) - c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half) - c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half) + c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16) + c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16) cutlass_w4a8_moe_mm( c1, @@ -166,7 +166,7 @@ def cutlass_w4a8_moe( topk, ) - intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half) + intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16) silu_and_mul(c1, intermediate) intermediate_q = torch.empty( diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh index 92cd58fed..d8b794997 100644 --- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh @@ -209,7 +209,7 @@ void cutlass_w4a8_group_gemm_caller( Args arguments; decltype(arguments.epilogue.thread) fusion_args; - fusion_args.alpha = 1.0f; + fusion_args.alpha = 0; fusion_args.beta = 0; fusion_args.alpha_ptr = a_scales.data_ptr(); ; diff --git a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py index f51d16b5a..3f9e60077 100644 --- a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py +++ b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py @@ -1,6 +1,6 @@ import pytest import torch -from sgl_kernel import cutlass_w4a8_moe_mm +from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8 from utils import is_hopper @@ -67,7 +67,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size): if debug: a = torch.ones(m, k, dtype=torch.bfloat16, device=device) ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device) - a_scale = torch.ones(1, dtype=torch.float, device=device) ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device) else: a = torch.randn(m, k, dtype=dtype, device=device) @@ -75,7 +74,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size): -8, 8, (num_experts, n, k), dtype=torch.int8, device=device ) affine_coeff = 0.005 - a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02 ref_w_scale = ( torch.randn(num_experts, n, k // 128, dtype=dtype, device=device) * affine_coeff @@ -93,7 +91,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size): s_strides = c_strides # Quantize input - a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device) + a_q, a_scale = _per_tensor_quant_fp8(a) # Create output tensor c = torch.empty((m, n), dtype=torch.bfloat16, device=device) @@ -117,7 +115,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size): # Reference implementation experts_selection_result = torch.full((m,), 0) c_ref = ref_grouped_gemm( - c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result + c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result ) # Compare results @@ -138,17 +136,29 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size): raise -# @pytest.mark.skipif( -# not is_hopper(), -# reason="cutlass_w4a8_moe_mm is only supported on sm90", -# ) +def _per_tensor_quant_fp8( + x: torch.Tensor, + dtype: torch.dtype = torch.float8_e4m3fn, +): + assert x.is_contiguous(), "`x` is not contiguous" + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_s = torch.empty( + 1, + device=x.device, + dtype=torch.float32, + ) + sgl_per_tensor_quant_fp8(x, x_q, x_s, is_static=False) + return x_q, x_s + + @pytest.mark.skipif( - True, - reason="TODO(rainj-me): fix cu129 binary issue on hopper cu126", + not is_hopper(), + reason="cutlass_w4a8_moe_mm is only supported on sm90", ) -@pytest.mark.parametrize("batch_size", [2, 4, 8, 16]) -@pytest.mark.parametrize("k", [256, 512, 1024]) -@pytest.mark.parametrize("n", [1024, 2048, 7168]) +@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32]) +@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168]) +@pytest.mark.parametrize("n", [256, 512, 1024, 2048]) @pytest.mark.parametrize("num_experts", [2, 4, 6, 8]) def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): torch.manual_seed(0) @@ -163,7 +173,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): if debug: a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device) ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device) - a_scale = torch.ones(1, dtype=torch.float, device=device) ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device) else: a = torch.randn(batch_size, k, dtype=dtype, device=device) @@ -171,7 +180,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): -8, 8, (num_experts, n, k), dtype=torch.int8, device=device ) affine_coeff = 0.005 - a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02 ref_w_scale = ( torch.randn(num_experts, n, k // 128, dtype=dtype, device=device) * affine_coeff @@ -202,12 +210,8 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device) # Permute input and quantize - a_perm = a[permutation] - a_q_perm = ( - torch.clamp((a_perm / a_scale), -448.0, 448.0) - .to(torch.float8_e4m3fn) - .to(device) - ) + a_q, a_scale = _per_tensor_quant_fp8(a) + a_q_perm = a_q[permutation] # Create stride tensors a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64) @@ -238,7 +242,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): c = c.to(dtype) c_ref = ref_grouped_gemm( - c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result + c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result ) # Compare results @@ -256,10 +260,11 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): raise -def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result): +def ref_grouped_gemm( + c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result +): dtype = torch.bfloat16 c_ref = torch.zeros_like(c) - a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn) for i in range(num_experts): token_idx = torch.where(experts_selection_result == i)[0] if len(token_idx) == 0: