From 58f9060efe26d4377af06dcb2e33778fb012e4f3 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 7 Jan 2025 19:47:37 +0800 Subject: [PATCH] Update int8 gemm config (#2774) --- .../src/sgl-kernel/csrc/int8_gemm_kernel.cu | 17 ++++++++++++++--- sgl-kernel/tests/test_int8_gemm.py | 4 ++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu index 7fb773e1a..cce32c2d8 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -88,10 +88,11 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); auto can_implement = gemm_op.can_implement(args); - TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + TORCH_CHECK(can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); auto status = gemm_op(args, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess) + TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status)); } template @@ -144,7 +145,17 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } - } else if (m <= 64 || (m <= 128 && n < 8192)) { + } else if (m <= 64) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } + } else if (m <= 128 && n < 8192) { cutlass_int8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, 
bias); diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py index 42f9dd229..34d17d1c7 100644 --- a/sgl-kernel/tests/test_int8_gemm.py +++ b/sgl-kernel/tests/test_int8_gemm.py @@ -37,8 +37,8 @@ class TestInt8Gemm(unittest.TestCase): print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") def test_accuracy(self): - Ms = [1, 128, 512, 1024, 4096] - Ns = [16, 128, 512, 1024, 4096] + Ms = [1, 128, 512, 1024, 4096, 8192] + Ns = [16, 128, 512, 1024, 4096, 8192, 16384] Ks = [512, 1024, 4096, 8192, 16384] bias_opts = [True, False] out_dtypes = [torch.float16, torch.bfloat16]