diff --git a/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu index f18c81865..b47904cb1 100644 --- a/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu +++ b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu @@ -409,8 +409,8 @@ void sm89_dispatch_shape( cutlass_int8_scaled_mm< ElementOutput, ArchTag, - cutlass::gemm::GemmShape<32, 64, 128>, - cutlass::gemm::GemmShape<16, 64, 64>, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); }