diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
index 9bc45ab1c..92cd58fed 100644
--- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
+++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
@@ -41,8 +41,8 @@
 using MmaType = cutlass::float_e4m3_t;     // FP8 e4m3 type
 using QuantType = cutlass::int4b_t;        // 4-bit integer type
 using ElementAccumulator = float;          // Accumulator type
 using ElementScale = cutlass::bfloat16_t;  // Scale type
-using ElementC = cutlass::half_t;          // Default output type (FP16)
-using ElementD = ElementC;                 // Default output type (FP16)
+using ElementC = cutlass::bfloat16_t;      // Output type
+using ElementD = ElementC;                 // Output type
 using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
 // Architecture-specific configurations
diff --git a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
index 4ad5d29f5..b0e209494 100644
--- a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
+++ b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
@@ -96,7 +96,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
 
     # Create output tensor
-    c = torch.empty((m, n), dtype=torch.float16, device=device)
+    c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
     cutlass_w4a8_moe_mm(
         c,
         a_q,
@@ -211,7 +211,7 @@
     b_strides = a_strides
     s_strides = c_strides
 
-    c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
+    c_perm = torch.empty((batch_size, n), dtype=torch.bfloat16, device=device)
     cutlass_w4a8_moe_mm(
         c_perm,
         a_q_perm,
@@ -262,10 +262,9 @@ def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_r
             continue
 
         a = a_q[token_idx]
-        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
-        ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
-        c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
-        c = c.to(dtype)
+        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(torch.float32)
+        ref_w = w[i].to(torch.float32) * ref_w_scale_repeat
+        c = torch.matmul(a.to(torch.float32), ref_w.t()) * a_scale
         c_ref[token_idx] = c.to(dtype)
 
     return c_ref