From 0ac019f17189e2ba3a3bab047cf441e060d339a1 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Tue, 21 Jan 2025 22:21:54 +0800
Subject: [PATCH] Support sm90 Int8 gemm (#3035)

---
 .../src/sgl-kernel/csrc/int8_gemm_kernel.cu | 210 +++++++++++++++++-
 sgl-kernel/tests/test_int8_gemm.py          |   2 +-
 2 files changed, 210 insertions(+), 2 deletions(-)

diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
index cce32c2d8..8e3f72757 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
@@ -3,13 +3,23 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/util/Optional.h>
 #include <cudaTypedefs.h>
+#include <cute/tensor.hpp>
 #include <torch/extension.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/epilogue/collective/collective_builder.hpp>
+#include <cutlass/gemm/collective/collective_builder.hpp>
+#include <cutlass/gemm/device/gemm_universal_adapter.h>
+#include <cutlass/gemm/kernel/gemm_universal.hpp>
+#include <cutlass/util/packed_stride.hpp>
+
 #include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
 #include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
 #include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
 #include "utils.hpp"
 
+using namespace cute;
+
 template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape, typename InstructionShape, int NumStages>
 void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
@@ -166,6 +176,186 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t
   }
 }
 
+template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType,
+          bool WithBias>
+void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+                                 const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+                                 const c10::optional<torch::Tensor>& bias) {
+  using ArchTag = cutlass::arch::Sm90;
+
+  using ElementAccumulator = int32_t;
+  using ElementCompute = float;
+  using ElementInputA = int8_t;
+  using ElementInputB = int8_t;
+
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
+  using TileSchedulerType = cutlass::gemm::PersistentScheduler;
+
+  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute,
+                                                             Stride<Int<1>, Int<0>, Int<0>>>;
+
+  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute,
+                                                             Stride<Int<0>, Int<1>, Int<0>>>;
+
+  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
+                                                           Stride<Int<0>, Int<1>, Int<0>>>;
+
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+
+  // Scale: x_scale * (w_scale * acc)
+  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute,
+                                                          cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
+
+  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute,
+                                                          cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
+
+  // With bias: x_scale * (w_scale * acc) + bias
+  using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute,
+                                                                 cutlass::FloatRoundStyle::round_to_nearest>;
+  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
+
+  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
+
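+  // The collective builders below assemble the epilogue and mainloop from
+  // TileShape/ClusterShape; StageCountAutoCarveout sizes the mainloop software
+  // pipeline around the shared memory reserved by the epilogue visitor tree.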
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput,
+      cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp;
+
+  using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+      sizeof(typename CollectiveEpilogue::SharedStorage))>;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB,
+      cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages,
+      MainloopScheduleType>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
+                                                          CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  Gemm gemm_op;
+
+  int m = mat_a.size(0);
+  int k = mat_a.size(1);
+  int n = mat_b.size(1);
+
+  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
+  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
+  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());
+
+  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
+  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
+  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
+  StrideC stride_c;
+  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
+
+  typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
+                                   {m, n, k, 1},
+                                   {a_ptr, stride_a, b_ptr, stride_b},
+                                   {{},  // epilogue.thread
+                                    nullptr,
+                                    stride_c,
+                                    o_ptr,
+                                    stride_d}};
+
+  if constexpr (WithBias) {
+    ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
+    args.epilogue.thread = {
+        {a_s_ptr},
+        {{b_s_ptr}, {}, {}},
+        {bias_ptr},
+        {},
+    };
+  } else {
+    args.epilogue.thread = {
+        {a_s_ptr},
+        {{b_s_ptr}, {}, {}},
+        {},
+    };
+  }
+
+  auto workspace = torch::empty(gemm_op.get_workspace_size(args),
+                                torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
+
+  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
+
+  auto can_implement = gemm_op.can_implement(args);
+  TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
+              "gemm cannot be implemented, error: ", cutlassGetStatusString(can_implement));
+
+  auto status = gemm_op(args, workspace.data_ptr(), stream);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
+}
+
+template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
+void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+                        const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+                        const c10::optional<torch::Tensor>& bias) {
+  if (bias) {
+    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
+        out, mat_a, mat_b, scales_a, scales_b, bias);
+  } else {
+    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, false>(
+        out, mat_a, mat_b, scales_a, scales_b, bias);
+  }
+}
+
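+// Shape-based dispatch for SM90: each (M, N) bucket selects a TileShape /
+// ClusterShape pair for the TMA warp-specialized kernel; once M exceeds 128,
+// the pingpong schedule is used instead.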
+template <typename ElementOutput>
+void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+                         const c10::optional<torch::Tensor>& bias) {
+  int m = mat_a.size(0);
+  int n = mat_b.size(1);
+  if (m <= 32) {
+    if (n < 8192) {
+      return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _256>, Shape<_1, _8, _1>,
+                                cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _256>, Shape<_1, _8, _1>,
+                                cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else if (m <= 64) {
+    if (n < 8192) {
+      return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _256>, Shape<_1, _4, _1>,
+                                cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _256>, Shape<_1, _1, _1>,
+                                cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else if (m <= 128) {
+    if (n <= 4096) {
+      return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_2, _1, _1>,
+                                cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      return sm90_dispatch_bias<ElementOutput, Shape<_128, _128, _128>, Shape<_2, _1, _1>,
+                                cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+  } else {
+    return sm90_dispatch_bias<ElementOutput, Shape<_128, _128, _128>, Shape<_2, _1, _1>,
+                              cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b,
+                                                                               bias);
+  }
+}
+
 torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias) {
@@ -204,7 +394,7 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
     TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
     sm75_dispatch_shape<cutlass::half_t, cutlass::gemm::GemmShape<8, 8, 16>>(
         out, mat_a, mat_b, scales_a, scales_b, bias);
-  } else if (sm_version >= 80 && sm_version <= 90) {
+  } else if (sm_version >= 80 && sm_version < 90) {
     if (out_dtype == torch::kBFloat16) {
       sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::gemm::GemmShape<16, 8, 32>>(
           out, mat_a, mat_b, scales_a, scales_b, bias);
@@ -212,6 +402,24 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
       sm80_dispatch_shape<cutlass::half_t, cutlass::gemm::GemmShape<16, 8, 32>>(
           out, mat_a, mat_b, scales_a, scales_b, bias);
     }
+  } else if (sm_version == 90) {
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+    // cutlass 3.x
+    if (out_dtype == torch::kBFloat16) {
+      sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+#else
+    // fallback to cutlass 2.x
+    if (out_dtype == torch::kBFloat16) {
+      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::gemm::GemmShape<16, 8, 32>>(
+          out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      sm80_dispatch_shape<cutlass::half_t, cutlass::gemm::GemmShape<16, 8, 32>>(
+          out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+#endif
   } else {
     TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
   }
diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py
index 34d17d1c7..c33a3effc 100644
--- a/sgl-kernel/tests/test_int8_gemm.py
+++ b/sgl-kernel/tests/test_int8_gemm.py
@@ -25,7 +25,7 @@ class TestInt8Gemm(unittest.TestCase):
         scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
         scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
         if with_bias:
-            bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10
+            bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
         else:
             bias = None
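For orientation, the epilogue wired up above computes D = scale_a (per-row) * scale_b (per-column) * (A @ B), plus an optional row-broadcast bias. Below is a minimal PyTorch sketch of that contract, useful for spot-checking the kernel; the `sgl_kernel.int8_scaled_mm` import and argument order are assumed from the test module and the C++ signature, and the tolerances are illustrative only:

```python
import torch
from sgl_kernel import int8_scaled_mm  # binding name assumed from the test module

def ref_int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias=None):
    # Mirrors the epilogue visitor tree: D = scale_a * (scale_b * acc) (+ bias),
    # with the int32 accumulation emulated in float32.
    acc = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a.view(-1, 1) * (scale_b.view(1, -1) * acc)
    if bias is not None:
        out = out + bias.to(torch.float32).view(1, -1)
    return out.to(out_dtype)

M, K, N = 128, 512, 256
a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
# The kernel treats B as column-major (K x N), so build it as N x K and transpose.
b = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda").t()
scale_a = torch.rand(M, dtype=torch.float32, device="cuda") * 0.01
scale_b = torch.rand(N, dtype=torch.float32, device="cuda") * 0.01

ref = ref_int8_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
out = int8_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, None)
torch.testing.assert_close(out.to(torch.float32), ref.to(torch.float32), rtol=1e-2, atol=1e-2)
```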