Support cutlass Int8 gemm (#2752)
@@ -31,6 +31,7 @@ add_library(_kernels SHARED
  src/sgl-kernel/csrc/trt_reduce_internal.cu
  src/sgl-kernel/csrc/trt_reduce_kernel.cu
  src/sgl-kernel/csrc/moe_align_kernel.cu
  src/sgl-kernel/csrc/int8_gemm_kernel.cu
  src/sgl-kernel/csrc/sgl_kernel_ops.cu
)

sgl-kernel/benchmark/bench_int8_gemm.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
import triton
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="int8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider):
    M, N, K = batch_size, 4096, 8192
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch.float16)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path="bench_int8_res")

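For reference, the throughput figure reported above comes from the gbps lambda, which counts elements rather than strict bytes moved. A small illustrative calculation for one shape (not part of the benchmark file), assuming int8 inputs (1 byte) and float32 scales (4 bytes):

    M, N, K = 1024, 4096, 8192
    ms = 1.0  # suppose the measured median latency is 1 ms
    reported = ((2 * M * N * K - M * N) * 1 + (3 * M * N) * 4) * 1e-9 / (ms * 1e-3)
    print(f"reported GB/s at M={M}: {reported:.1f}")
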
@@ -26,6 +26,7 @@ cutlass = root / "3rdparty" / "cutlass"
include_dirs = [
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    root / "src" / "sgl-kernel" / "csrc",
]
nvcc_flags = [
    "-O3",
@@ -48,6 +49,7 @@ ext_modules = [
            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
            "src/sgl-kernel/csrc/moe_align_kernel.cu",
            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
        ],
        include_dirs=include_dirs,

@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
    custom_dispose,
    custom_reduce,
    init_custom_reduce,
    int8_scaled_mm,
    moe_align_block_size,
)

@@ -10,4 +11,5 @@ __all__ = [
    "init_custom_reduce",
    "custom_dispose",
    "custom_reduce",
    "int8_scaled_mm",
]

sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h (new file, 278 lines)
@@ -0,0 +1,278 @@
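This visitor applies a per-row (per-token) scale, a per-column (per-channel) scale, and an optional bias while the int32 accumulators are written out. In the notation of the reference check in tests/test_int8_gemm.py, the epilogue computes, elementwise over the output tile,

    D_ij = (sum_k A_ik * B_kj) * scale_a_i * scale_b_j + bias_j

with the bias term skipped when no bias pointer is supplied, and a single scalar used in place of a vector when per-token or per-channel quantization is disabled.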
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
|
||||
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_,
|
||||
bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol {
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
|
||||
int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_),
|
||||
batch_stride_alpha(batch_stride_alpha_),
|
||||
batch_stride_C(batch_stride_C_),
|
||||
batch_stride_D(batch_stride_D_) {}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise),
|
||||
batch_stride_alpha(args.batch_stride_alpha),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D) {}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
bool const with_bias_;
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator beta_;
|
||||
|
||||
int column_offset_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
|
||||
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
|
||||
typename ScaleTileIterator::Params params_alpha_col,
|
||||
typename OutputTileIterator::Params params_C,
|
||||
typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant,
|
||||
bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row,
|
||||
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
|
||||
typename OutputTileIterator::Element* ptr_D,
|
||||
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
|
||||
int column_offset = 0,
|
||||
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
|
||||
: params_(params),
|
||||
shared_storage_(shared_storage),
|
||||
extent_(problem_size),
|
||||
elementwise_(params.elementwise),
|
||||
with_bias_(with_bias),
|
||||
per_token_quant_(per_token_quant),
|
||||
per_channel_quant_(per_channel_quant),
|
||||
ptr_alpha_row_(ptr_alpha_row),
|
||||
ptr_alpha_col_(ptr_alpha_col),
|
||||
iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
|
||||
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
|
||||
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
|
||||
extent_real_(problem_size_real) {
|
||||
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
|
||||
element_alpha_col_ = *ptr_alpha_col_;
|
||||
}
|
||||
|
||||
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
|
||||
element_alpha_row_ = *ptr_alpha_row_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (per_channel_quant_) {
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
|
||||
if (with_bias_) {
|
||||
iterator_C_.load(fragment_C_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_D_.clear();
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_) {
|
||||
int thread_offset_row =
|
||||
iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) {
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_) {
|
||||
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
|
||||
} else {
|
||||
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
if (with_bias_) {
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
|
||||
OutputVector bias = reinterpret_cast<OutputVector*>(&fragment_C_)[column_idx];
|
||||
result = bias_accumulator_(result, bias_converter(bias));
|
||||
}
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
|
||||
AlphaScaleElementType const& scale_row) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i) {
|
||||
result[i] = accum[i] * (scale_col[i] * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
|
||||
AlphaScaleElementType const& scale_row) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i) {
|
||||
result[i] = accum[i] * (scale_col * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment bias_accumulator_(ComputeFragment const& accum, ComputeFragment const& bias) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputVector::kElements; ++i) {
|
||||
result[i] = accum[i] + bias[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h (new file, 346 lines)
@@ -0,0 +1,346 @@
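For split-K modes, get_grid_shape_ below splits K across batch_count slices and rounds each slice up to a 128-bit-aligned element count. A minimal sketch of that arithmetic in Python (illustrative only, mirroring the C++ further down):

    def ceil_div(a, b):
        return (a + b - 1) // b

    def round_up(a, b):
        return ceil_div(a, b) * b

    k, split_k_slices, k_align = 8192, 3, 16  # int8: 128 bits / 8 bits = 16 elements
    gemm_k_size = round_up(ceil_div(k, split_k_slices), k_align)  # K elements per slice -> 2736
    grid_k = ceil_div(k, gemm_k_size)                             # K tiles in the grid -> 3
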
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
|
||||
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
|
||||
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
|
||||
|
||||
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
|
||||
that feature at the moment.
|
||||
*/
|
||||
|
||||
template <typename GemmKernel_>
|
||||
class GemmUniversalBaseCompat {
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
||||
|
||||
using ElementA = typename GemmKernel::ElementA;
|
||||
using LayoutA = typename GemmKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
|
||||
|
||||
using ElementB = typename GemmKernel::ElementB;
|
||||
using LayoutB = typename GemmKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
|
||||
|
||||
using ElementC = typename GemmKernel::ElementC;
|
||||
using LayoutC = typename GemmKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
|
||||
using Operator = typename GemmKernel::Operator;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename GemmKernel::Arguments;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
protected:
|
||||
/// Private helper to obtain the grid dimensions with fix-up for split-K
|
||||
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
gemm_k_size = args.problem_size.k();
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
int const kAlignK =
|
||||
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size) {
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalBaseCompat() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args) {
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
|
||||
|
||||
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return GemmKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
// Split-K parallel always requires a temporary workspace
|
||||
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
|
||||
} else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
|
||||
// Serial split-K only requires a temporary workspace if the number of partitions along the
|
||||
// GEMM K dimension is greater than one.
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
|
||||
<< " result = {" << result << "}");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10)) {
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
|
||||
GemmKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
} else {
|
||||
// Query assuming zero shared memory then compute occupancy limit based on SMEM
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
|
||||
GemmKernel::kThreadCount, 0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (smem_capacity < 0) {
|
||||
int device_idx = 0;
|
||||
result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp properties;
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
if (workspace_bytes) {
|
||||
if (!workspace) {
|
||||
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
|
||||
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm) {
|
||||
CUTLASS_TRACE_HOST(" clearing device workspace");
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get CUDA grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result =
|
||||
cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_.update(args, workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
|
||||
|
||||
//
|
||||
// Configure grid and block dimensions
|
||||
//
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h (new file, 456 lines)
@@ -0,0 +1,456 @@
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmWithEpilogueVisitor {
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using TensorRefB = TensorRef<ElementB, LayoutB>;
|
||||
|
||||
using ElementCompute = typename EpilogueVisitor::ElementCompute;
|
||||
using LayoutAlphaCol = cutlass::layout::RowMajor;
|
||||
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
|
||||
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
|
||||
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
|
||||
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename Epilogue::Layout;
|
||||
using TensorRefC = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
using EpilogueOutputOp =
|
||||
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
TensorRefAlphaCol ref_alpha_col;
|
||||
TensorRefAlphaRow ref_alpha_row;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_,
|
||||
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor_)
|
||||
: mode(GemmUniversalMode::kGemm),
|
||||
problem_size(problem_size_),
|
||||
batch_count(1),
|
||||
ref_A(ref_A_),
|
||||
ref_B(ref_B_),
|
||||
ref_alpha_col(ref_alpha_col_),
|
||||
ref_alpha_row(ref_alpha_row_),
|
||||
ref_C(ref_C_),
|
||||
ref_D(ref_D_),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_D(0),
|
||||
epilogue_visitor(epilogue_visitor_) {}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void* ptr_A;
|
||||
void* ptr_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_alpha_col(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_alpha_col(nullptr),
|
||||
ptr_alpha_row(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0) {}
|
||||
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
|
||||
: problem_size(args.problem_size),
|
||||
swizzle_log_tile(0),
|
||||
params_A(args.ref_A.layout()),
|
||||
params_B(args.ref_B.layout()),
|
||||
params_alpha_col(args.ref_alpha_col.layout()),
|
||||
params_alpha_row(args.ref_alpha_col.layout()),
|
||||
params_C(args.ref_C.layout()),
|
||||
params_D(args.ref_D.layout()),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(args.problem_size.k()),
|
||||
ptr_A(args.ref_A.data()),
|
||||
ptr_B(args.ref_B.data()),
|
||||
ptr_alpha_col(args.ref_alpha_col.data()),
|
||||
ptr_alpha_row(args.ref_alpha_row.data()),
|
||||
ptr_C(args.ref_C.data()),
|
||||
ptr_D(args.ref_D.data()),
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
epilogue_visitor(args.epilogue_visitor) {
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
int const kAlignK =
|
||||
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size) {
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
|
||||
struct {
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
} epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemmWithEpilogueVisitor() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
|
||||
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
|
||||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
|
||||
platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
|
||||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args) {
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define SPLIT_K_ENABLED 1
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage) {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
|
||||
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
} else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
} else if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
#endif
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx,
|
||||
tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx,
|
||||
tb_offset_B);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
//
|
||||
// Construct the epilogue visitor
|
||||
//
|
||||
|
||||
bool with_bias = true;
|
||||
if (params.ptr_C == nullptr) {
|
||||
with_bias = false;
|
||||
}
|
||||
|
||||
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(),
|
||||
thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
|
||||
params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col,
|
||||
params.ptr_C, params.ptr_D, threadblock_offset,
|
||||
blockIdx.y * params.problem_size.m());
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
} else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) {
|
||||
if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
|
||||
run_kernel_(params, shared_storage);
|
||||
} else {
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage) {
|
||||
run_kernel<ArchTag>(params, shared_storage);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu (new file, 209 lines)
@@ -0,0 +1,209 @@
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
#include "utils.hpp"

template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape,
          typename InstructionShape, int NumStages>
void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

  using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA,
                                                                          ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB,
      cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
      ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel;

  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
      EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;

  using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  int64_t lda = mat_a.stride(0);
  int64_t ldb = mat_b.stride(1);
  int64_t ldd = out.stride(0);

  ElementOutput* bias_ptr = nullptr;
  int64_t ldc = 0;
  if (bias) {
    bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
  }

  typename EpilogueOutputOp::Params linearScalingParams;
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  typename Gemm::Arguments args{{m, n, k},    {a_ptr, lda},    {b_ptr, ldb},  {b_s_ptr, 0},
                                {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd},  visitor_args};

  auto workspace = torch::empty(gemm_op.get_workspace_size(args),
                                torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess)

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess)
}

template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  if (m <= 32) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 128, 64>,
                           cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
  } else if (m <= 64) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
  } else if (m <= 256) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 128>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
  }
}

template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 16) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>,
                             cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
                                                                                         scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>,
                             cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                         scales_b, bias);
    }
  } else if (m <= 32) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>,
                             cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a,
                                                                                         scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>,
                             cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                         scales_b, bias);
    }
  } else if (m <= 64 || (m <= 128 && n < 8192)) {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>,
                           cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
  }
}

torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment");  // out.stride(0)
  TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
  TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));

  auto sm_version = getSMVersion();

  if (sm_version >= 75 && sm_version < 80) {
    TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
    sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (sm_version >= 80 && sm_version <= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
  }

  return out;
}

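Taken together, the checks above define the calling convention: row-major int8 A, column-major int8 B, contiguous float32 per-token and per-channel scales, K a multiple of 16 and N a multiple of 8. A minimal sketch of a conforming call from Python (illustrative, with random data in place of a real quantizer):

    import torch
    from sgl_kernel import int8_scaled_mm

    M, N, K = 128, 4096, 8192  # K % 16 == 0, N % 8 == 0
    a = torch.randint(-128, 128, (M, K), device="cuda", dtype=torch.int8)      # row major
    b = torch.randint(-128, 128, (N, K), device="cuda", dtype=torch.int8).t()  # column-major view
    scale_a = torch.rand(M, device="cuda", dtype=torch.float32)  # per-token scales
    scale_b = torch.rand(N, device="cuda", dtype=torch.float32)  # per-channel scales
    out = int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, None)  # (M, N) float16
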
@@ -12,6 +12,11 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);

// int8_scaled_mm
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -19,4 +24,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
  // moe_align_block_size
  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
  // int8_scaled_mm
  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
}

@@ -34,3 +34,13 @@ struct cuda_error : public std::runtime_error {
#define CHECK_CUDA_INPUT(x) \
  CHECK_IS_CUDA(x);         \
  CHECK_IS_CONTIGUOUS(x)

inline int getSMVersion() {
  int device{-1};
  CHECK_CUDA_SUCCESS(cudaGetDevice(&device));
  int sm_major = 0;
  int sm_minor = 0;
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
  return sm_major * 10 + sm_minor;
}

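getSMVersion returns the compute capability as major * 10 + minor (75, 80, 89, 90, ...), which is what the int8 GEMM dispatch above keys on. An illustrative Python equivalent for experimentation, using PyTorch's device query:

    import torch

    def get_sm_version(device: int = 0) -> int:
        major, minor = torch.cuda.get_device_capability(device)
        return major * 10 + minor
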
@@ -1,6 +1,7 @@
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size


@@ -36,3 +37,14 @@ def moe_align_block_size(
        token_cnts_buffer,
        cumsum_buffer,
    )


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return _int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )

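The wrapper expects operands that are already quantized to int8 plus float32 scales. A rough sketch of how a caller might produce them with simple symmetric absmax quantization (the quantization scheme here is an assumption for illustration, not something this commit ships):

    import torch
    from sgl_kernel import int8_scaled_mm

    def absmax_quantize(x, dim):
        scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0
        q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return q, scale.squeeze(dim).to(torch.float32)

    x = torch.randn(64, 4096, device="cuda")    # activations (M, K)
    w = torch.randn(8192, 4096, device="cuda")  # weights (N, K)
    a, scale_a = absmax_quantize(x, dim=1)      # per-token scales, shape (M,)
    b, scale_b = absmax_quantize(w, dim=1)      # per-channel scales, shape (N,)
    y = int8_scaled_mm(a, b.t(), scale_a, scale_b, torch.float16, None)  # (M, N) float16
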
sgl-kernel/tests/test_int8_gemm.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import unittest

import torch
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    if bias is not None:
        o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias
    else:
        o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
    return o.to(out_dtype)


class TestInt8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
        a = to_int8(torch.randn((M, K), device=device) * 5)
        b = to_int8(torch.randn((N, K), device=device).t() * 5)
        scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
        scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
        if with_bias:
            bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10
        else:
            bias = None

        o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        torch.testing.assert_close(o, o1)
        torch.testing.assert_close(o, o2)
        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [16, 128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        bias_opts = [True, False]
        out_dtypes = [torch.float16, torch.bfloat16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for with_bias in bias_opts:
                        for out_dtype in out_dtypes:
                            self._test_accuracy_once(
                                M, N, K, with_bias, out_dtype, "cuda"
                            )


if __name__ == "__main__":
    unittest.main()