Files
sglang/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu

210 lines
12 KiB
Plaintext
Raw Normal View History

2025-01-06 22:51:22 +08:00
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
#include "utils.hpp"
// Runs one CUTLASS 2.x int8 x int8 GEMM for a single, fixed tile configuration. Scaling by
// scales_a / scales_b and the optional bias are fused into the epilogue through
// EpilogueVisitorPerRowPerCol.
//
// Template parameters:
//   ElementOutput    - CUTLASS output element type; must match out's dtype (cutlass::half_t or
//                      cutlass::bfloat16_t at the call sites in this file).
//   ArchTag          - CUTLASS architecture tag the kernel is built for.
//   ThreadblockShape - per-threadblock tile (M, N, K).
//   WarpShape        - per-warp tile (M, N, K).
//   InstructionShape - tensor-core instruction shape (M, N, K).
//   NumStages        - number of mainloop pipeline stages.
//
// Inputs are expected to be validated by the caller (int8_scaled_mm):
//   - mat_a: int8 [m, k] with unit column stride.
//   - mat_b: int8 [k, n] with unit row stride.
//   - scales_a: float32 with m values.
//   - scales_b: float32 with n values.
//   - bias: optional, n values of ElementOutput.
//   - out: [m, n] ElementOutput with unit column stride.
// The GEMM is enqueued on the current CUDA stream of mat_a's device. Raises c10::Error if CUTLASS
// rejects the problem or the GEMM call does not return kSuccess.
template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape,
          typename InstructionShape, int NumStages>
void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

  using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA,
                                                                          ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  // Base kernel: A row-major, B column-major, output row-major. Only its Mma, Epilogue and
  // kThreadCount are reused below; the kernel actually launched is GemmWithEpilogueVisitor.
  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB,
      cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
      ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel;

  // Float (ElementCompute) tile iterator whose thread map mirrors the output iterator's shape,
  // count, thread count and access width; handed to the visitor as its "alpha col" iterator.
  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
      EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;

  using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

  Gemm gemm_op;

  const int m = static_cast<int>(mat_a.size(0));
  const int k = static_cast<int>(mat_a.size(1));
  const int n = static_cast<int>(mat_b.size(1));

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());
  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  // Leading dimensions: row stride of A and of the output, column stride of B.
  int64_t lda = mat_a.stride(0);
  int64_t ldb = mat_b.stride(1);
  int64_t ldd = out.stride(0);

  // The bias is passed as the C operand with leading dimension 0, so every output row reads the
  // same n-element vector. Without a bias the pointer stays null.
  ElementOutput* bias_ptr = nullptr;
  int64_t ldc = 0;
  if (bias) {
    bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
  }

  typename EpilogueOutputOp::Params linearScalingParams;
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  // NOTE(review): the scale vectors are passed as scales_b (n values, one per output column) and then
  // scales_a (m values, one per output row); this order follows GemmWithEpilogueVisitor::Arguments —
  // confirm against that header before reordering.
  typename Gemm::Arguments args{{m, n, k},       {a_ptr, lda},      {b_ptr, ldb}, {b_s_ptr, 0},
                                {a_s_ptr, 0},    {bias_ptr, ldc},   {o_ptr, ldd}, visitor_args};

  auto workspace = torch::empty(gemm_op.get_workspace_size(args),
                                torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
              "cutlass_int8_scaled_mm: problem cannot be implemented by this kernel configuration: ",
              cutlass::cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "cutlass_int8_scaled_mm: CUTLASS GEMM failed: ",
              cutlass::cutlassGetStatusString(status));
}
// Tile selection for the SM75 path. Picks a threadblock/warp shape from the number of rows of
// mat_a and forwards every argument unchanged to cutlass_int8_scaled_mm. All SM75 configurations
// use a 2-stage mainloop.
//
//   rows of mat_a   threadblock     warp
//   > 256           128x128x64      64x64x64
//   65 .. 256       128x128x128     64x64x64
//   33 .. 64        64x128x128      64x64x64
//   <= 32           32x128x64       32x64x64
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias) {
  constexpr int kStages = 2;
  using Warp64 = cutlass::gemm::GemmShape<64, 64, 64>;

  const int rows = mat_a.size(0);

  if (rows > 256) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>, Warp64,
                           InstructionShape, kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  if (rows > 64) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 128>, Warp64,
                           InstructionShape, kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  if (rows > 32) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>, Warp64,
                           InstructionShape, kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }

  cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 128, 64>,
                         cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, kStages>(out, mat_a, mat_b,
                                                                                           scales_a, scales_b, bias);
}
// Tile selection for the SM80 path. Picks the threadblock/warp shape and the pipeline depth from
// m (rows of mat_a) and n (columns of mat_b), then forwards every argument unchanged to
// cutlass_int8_scaled_mm.
//
//   m <= 16                          16x64x128 / 16x64x64, 6 stages if n <= 4096 else 5
//   16 < m <= 32                     32x64x128 / 32x64x64, 6 stages if n <= 4096 else 5
//   m <= 64, or m <= 128 && n < 8192 64x128x128 / 64x64x64, 5 stages
//   anything larger                  128x128x64 / 64x64x64, 5 stages
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias) {
  const int rows = mat_a.size(0);
  const int cols = mat_b.size(1);
  const bool narrow_output = cols <= 4096;

  if (rows <= 16) {
    using Tile = cutlass::gemm::GemmShape<16, 64, 128>;
    using Warp = cutlass::gemm::GemmShape<16, 64, 64>;
    if (narrow_output) {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, Tile, Warp, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, Tile, Warp, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
    }
    return;
  }

  if (rows <= 32) {
    using Tile = cutlass::gemm::GemmShape<32, 64, 128>;
    using Warp = cutlass::gemm::GemmShape<32, 64, 64>;
    if (narrow_output) {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, Tile, Warp, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, Tile, Warp, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                      scales_b, bias);
    }
    return;
  }

  using Warp64 = cutlass::gemm::GemmShape<64, 64, 64>;
  const bool medium_problem = rows <= 64 || (rows <= 128 && cols < 8192);

  if (medium_problem) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>, Warp64, InstructionShape,
                           5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>, Warp64, InstructionShape,
                           5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
// Scaled int8 GEMM entry point. Returns a new [m, n] tensor of `out_dtype` computed from the
// int8 operands mat_a [m, k] and mat_b [k, n]. The per-row scales_a (m values), per-column
// scales_b (n values) and optional per-column bias (n values) are applied in the CUTLASS epilogue.
// NOTE(review): the exact epilogue formula lives in EpilogueVisitorPerRowPerCol; presumably
// out[i][j] = scales_a[i] * scales_b[j] * sum_k(mat_a[i][k] * mat_b[k][j]) + bias[j].
//
// Requirements (each enforced with TORCH_CHECK, raising c10::Error):
//   mat_a     - int8 CUDA tensor, 2-D, row-major (stride(1) == 1), k % 16 == 0.
//   mat_b     - int8 tensor on mat_a's device, 2-D, column-major (stride(0) == 1), n % 8 == 0.
//   scales_a  - contiguous float32 tensor on mat_a's device with m elements.
//   scales_b  - contiguous float32 tensor on mat_a's device with n elements.
//   out_dtype - torch::kHalf or torch::kBFloat16; only kHalf is accepted on SM75-79.
//   bias      - optional contiguous tensor on mat_a's device with n elements of dtype out_dtype.
//
// Dispatch by compute capability:
//   - SM75-79 uses the Sm75 kernels.
//   - SM80-90 uses the Sm80 kernels.
//   - Any other capability raises a not-implemented error.
// An empty result (m == 0 or n == 0) is returned without launching a kernel.
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias) {
  // Operands.
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_b.device() == mat_a.device(), "mat_b must be on the same device as mat_a");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
  // n is also out.stride(0) (out is allocated contiguous below), so this keeps output rows aligned too.
  TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
  TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  // Scale vectors: their raw pointers are handed straight to the kernel, so they must live on the
  // same device as the operands.
  TORCH_CHECK(scales_a.device() == mat_a.device(), "scales_a must be on the same device as mat_a");
  TORCH_CHECK(scales_b.device() == mat_a.device(), "scales_b must be on the same device as mat_a");
  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->device() == mat_a.device(), "bias must be on the same device as mat_a");
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->scalar_type() == out_dtype, "bias dtype must match output dtype");
  }

  // Validate the architecture/dtype combination up front so the errors fire even for empty inputs.
  const auto sm_version = getSMVersion();
  const bool use_sm75 = sm_version >= 75 && sm_version < 80;
  const bool use_sm80 = sm_version >= 80 && sm_version <= 90;
  TORCH_CHECK_NOT_IMPLEMENTED(use_sm75 || use_sm80,
                              "No implemented int8_scaled_mm for current compute capability.");
  if (use_sm75) {
    TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));

  // Nothing to compute for an empty product. Skipping the call avoids handing CUTLASS an m == 0 or
  // n == 0 problem, whose tile grid would be empty.
  if (out.numel() == 0) {
    return out;
  }

  if (use_sm75) {
    sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (out_dtype == torch::kBFloat16) {
    sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  }
  return out;
}