Rename files in sgl kernel to avoid nested folder structure (#4213)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Lianmin Zheng
2025-03-08 22:54:51 -08:00
committed by GitHub
parent ee132a4515
commit 8abf74e3c9
47 changed files with 184 additions and 199 deletions

View File

@@ -0,0 +1,172 @@
// References:
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <cublas_v2.h>
#include <cudaTypedefs.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <torch/extension.h>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
#include "utils.h"
// Validate that the three tensor lists describe the same number of GEMM groups.
static void check_group_count(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs) {
  const size_t n_groups = inputs.size();
  TORCH_CHECK(
      weights.size() == n_groups && outputs.size() == n_groups,
      "The group count of inputs, weights and outputs should be the same.");
}
// Every tensor must match `dtype` and live on a CUDA device.
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
  for (size_t idx = 0; idx < tensors.size(); ++idx) {
    const torch::Tensor& tensor = tensors[idx];
    TORCH_CHECK(dtype == tensor.dtype(), "dtype of all the tensors should be the same");
    TORCH_CHECK(tensor.is_cuda(), "All tensors should be in Cuda memory");
  }
}
// Collect size(dim) of each tensor; all tensors must be 2D matrices.
static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  results.reserve(tensors.size());
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
    results.push_back(t.size(dim));
  }
  // Return by name: `return std::move(results)` pessimizes by disabling NRVO.
  return results;
}
// Collect stride(dim) of each tensor (used as cuBLAS leading dimensions).
static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  results.reserve(tensors.size());
  for (const auto& t : tensors) {
    results.push_back(t.stride(dim));
  }
  // Return by name: `return std::move(results)` pessimizes by disabling NRVO.
  return results;
}
// Element-wise equality check of two dimension vectors; raises `err_msg` on
// the first mismatch. Also guards against differing lengths, which previously
// could read past the end of `b`.
static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
  TORCH_CHECK(a.size() == b.size(), err_msg);
  for (size_t i = 0; i < a.size(); ++i) {
    TORCH_CHECK(a[i] == b[i], err_msg);
  }
}
// Gather the raw device data pointers of all tensors.
static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
  std::vector<void*> ptrs;
  ptrs.reserve(tensors.size());
  for (const auto& t : tensors) {
    ptrs.push_back(t.data_ptr());
  }
  // Return by name: `return std::move(ptrs)` pessimizes by disabling NRVO.
  return ptrs;
}
// Upload a host array of device pointers to GPU memory (cuBLAS grouped GEMM
// reads the pointer arrays from device). Backing storage is a torch tensor so
// its lifetime is managed automatically; the caller must keep it alive until
// the stream work completes.
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
  // kDouble elements are 8 bytes each, matching 64-bit pointers.
  static_assert(sizeof(void*) == sizeof(double), "pointer staging buffer assumes 64-bit pointers");
  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
  torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
  // Bug fix: cudaMemcpyAsync returns a cudaError_t; the original compared it
  // against CUBLAS_STATUS_SUCCESS (an unrelated enum that only happens to be 0).
  cudaError_t err =
      cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream);
  TORCH_CHECK(err == cudaSuccess, "Failed to copy pointer array to device: ", cudaGetErrorString(err));
  return gpu_ptrs;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
//
// Grouped GEMM: for each group i, computes outputs[i] = inputs[i] @ weights[i]^T
// via a single cublasGemmGroupedBatchedEx call (requires CUDA >= 12.5).
// All tensors must be fp16 or bf16 (== out_dtype) and on CUDA; accumulation is fp32.
// `cublas_handle` / `cuda_stream` are raw handles passed across the binding as integers.
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can"
      "only be applied to float16 and bfloat16 dtype");
  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);
  // One problem per group (batch size 1 within each group).
  std::vector<int> group_size(group_count, 1);
  // Make sure all tensors are on cuda and use the same dtype
  check_device_dtype(out_dtype, inputs);
  check_device_dtype(out_dtype, weights);
  check_device_dtype(out_dtype, outputs);
  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
  // Weights should be transposed to (n, k) of column major
  std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
  std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);
  // Get dim arrays
  std::vector<int> m_array = get_dims(weights, 0);
  std::vector<int> n_array = get_dims(inputs, 0);
  std::vector<int> k_array = get_dims(inputs, 1);
  // Make sure the dimensions in each group match
  std::vector<int> m_array1 = get_dims(outputs, 1);
  std::vector<int> n_array1 = get_dims(outputs, 0);
  std::vector<int> k_array1 = get_dims(weights, 1);
  check_equal(m_array, m_array1, "sizes don't match on m dimension");
  check_equal(n_array, n_array1, "sizes don't match on n dimension");
  check_equal(k_array, k_array1, "sizes don't match on k dimension");
  // Get leading dimensions
  std::vector<int> lda_array = get_strides(weights, 0);
  std::vector<int> ldb_array = get_strides(inputs, 0);
  std::vector<int> ldc_array = get_strides(outputs, 0);
  // Use default scaling factors
  std::vector<float> alpha_array(group_count, 1);
  std::vector<float> beta_array(group_count, 0);
  std::vector<void*> a_array = get_tensor_ptrs(weights);
  std::vector<void*> b_array = get_tensor_ptrs(inputs);
  std::vector<void*> c_array = get_tensor_ptrs(outputs);
  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  // Should allocate tensors for storage of pointers
  // (cuBLAS reads the A/B/C pointer arrays from device memory; these tensors
  // must outlive the stream synchronization below.)
  torch::Tensor d_a = create_ptr_pointer(a_array, stream);
  torch::Tensor d_b = create_ptr_pointer(b_array, stream);
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);  // fp32 accumulation
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  // Blocking sync keeps the host-side dim/pointer vectors alive for the launch.
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
  return;
#endif
  // Only reached when compiled with CUDA < 12.5 (the API is unavailable).
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}

View File

@@ -0,0 +1,226 @@
#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "utils.h"
using namespace cute;
// Launch a CUTLASS SM90 (Hopper) GEMM with blockwise FP8 scaling:
//   out = (a @ b) with per-block scale factors scales_a / scales_b applied in
//   the mainloop (block-scaled FP8 accumulation schedule).
// a: (m, k) e4m3 row major; b: (k, n) e4m3 column major; out: (m, n) OutType.
// ScaleGranularityM controls the per-M granularity of the scaling schedule.
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;
  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  // ElementC = void: no source (C) matrix; epilogue only stores the accumulator.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  // Epilogue fusion: plain accumulator fetch (scaling happens in the mainloop).
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
  // Custom schedule from cutlass_extensions: FP8 block-scaled accumulation.
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      // Carve the epilogue's shared-memory footprint out of the stage count.
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;
  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);
  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());
  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  // Packed strides for a single (batch = 1) problem.
  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
  // NOTE(review): the literal 4 is a mainloop tuning knob of the block-scaled
  // schedule (presumably the MMA promotion interval) — confirm against the
  // cutlass_extensions MainloopArguments definition.
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
  // No C source tensor (ElementC == void), so the C pointer slot is nullptr.
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},  // problem shape (M, N, K, batch)
      mainloop_args,
      epilogue_args,
  };
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
// Select the tile/cluster configuration for the SM90 blockwise-scaled FP8 GEMM
// and launch it. Currently a single static configuration is used.
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  // 128x128x128 CTA tile with a 1x1x1 cluster.
  launch_sm90_fp8_blockwise_scaled_mm<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>>(
      out, a, b, scales_a, scales_b);
}
// Blockwise-scaled FP8 GEMM entry point: returns a new (m, n) tensor of
// `out_dtype` computed as mat_a @ mat_b with per-block fp32 scale factors.
// mat_a: (m, k) e4m3 row major; mat_b: (k, n) e4m3 column major.
// scales_a: (m, k/128) fp32, M major; scales_b: (k/128, n/128) fp32, K major.
// Requires SM90+ and a CUDA 12 toolkit; otherwise raises NotImplementedError.
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  // Bug fix: this check is about mat_b; the original message said "mat_a".
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  // 16-byte alignment is required for the 128-bit TMA/vectorized accesses.
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
  // A "contiguous vector" here is a 1D tensor, or a 2D tensor with a
  // degenerate dimension, laid out contiguously.
  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };
  // NOTE(review): k/128 and n/128 use floor division, so a k or n that is not
  // a multiple of 128 would pass with truncated scale dims — confirm callers
  // guarantee divisibility by the 128-element scaling block.
  TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched");
  TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
  TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched");
  TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");
  auto sm_version = getSMVersion();
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
    } else {
      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
    }
    return out;
  }
#endif
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}

View File

@@ -0,0 +1,859 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h
#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "utils.h"
using namespace cute;
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
// Type factory for an SM89 (Ada) FP8 rowwise-scaled GEMM built on the CUTLASS
// 2.x threadblock epilogue-visitor-tree (EVT) API. The exposed `Gemm` type
// computes D = a_scale_col * (b_scale_row * (A @ B)) [+ bias_row], where the
// per-row/per-column scales and optional bias are fused into the epilogue.
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CtaShape,
    typename WarpShape,
    int Stages,
    bool WithBias,
    typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
    template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
    typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
  // A operand: FP8, row major, 16-byte aligned access.
  using ElementA = ElementType;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  // B operand: FP8, column major.
  using ElementB = ElementType;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  using ElementC = OutElementType;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  using ElementOutput = OutElementType;
  using LayoutOutput = cutlass::layout::RowMajor;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  using ElementAccumulator = AccumElementType;
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm89;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
  // Number of epilogue stages in EVT
  static constexpr int EVTEpilogueStages = 1;
  using OutputTileThreadMap = cutlass::epilogue::threadblock::
      OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;
  // Definition of EVT
  // Stage 1: acc * scales_b, where scales_b is broadcast along rows (shape (1, n)).
  using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
  using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using bScaleSrc = cutlass::epilogue::threadblock::
      VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;
  using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;
  // Stage 2: result * scales_a, where scales_a is broadcast along columns (shape (m, 1)).
  using ComputeAScale = cutlass::epilogue::threadblock::
      VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  using aScaleSrc = cutlass::epilogue::threadblock::
      VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;
  using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;
  // With bias
  // Stage 2 (bias variant): fused multiply-add, adding a row-broadcast bias.
  using biasSrc =
      cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
  using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add,
      ElementC,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EpilogueAScaleWithBias =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;
  // Final stage: store the result to D (row-major, runtime leading dimension).
  using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap,
      ElementC,
      cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, _1, _0>>;
  // Select the bias or bias-free visitor tree at compile time.
  using EpilogueStore = typename cutlass::platform::conditional<
      WithBias,
      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
  using EpilogueOp = EpilogueStore;
  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementA,
      LayoutA,
      cutlass::ComplexTransform::kNone,
      AlignmentA,
      ElementB,
      LayoutB,
      cutlass::ComplexTransform::kNone,
      AlignmentB,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementAccumulator,
      ElementComputeEpilogue,
      OperatorClass,
      ArchTag,
      CtaShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      ThreadblockSwizzle,
      Stages,
      FP8MathOperator,
      EVTEpilogueStages>::GemmKernel;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Build the CUTLASS Arguments struct for the SM89 FP8 rowwise GEMM.
// a: (m, k) row major FP8; b: (k, n) column major FP8; out: (m, n) row major.
// scales_a is per-row (length m), scales_b per-column (length n); both fp32.
// The nested braced initializers below mirror the EVT tree structure declared
// in DeviceGemmFp8RowwiseSm89 and are position-sensitive — do not reorder.
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);
  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value())
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
  typename Gemm::Arguments args(
      cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
      {m, n, k},                                // Problem size
      1,                                        // Split-k factor
      {},                                       // Epilogue args
      ptr_a,                                    // a pointer
      ptr_b,                                    // b pointer
      nullptr,                                  // c pointer (unused)
      nullptr,                                  // d pointer (unused)
      m * k,                                    // batch stride a (unused)
      n * k,                                    // batch stride b (unused)
      m * n,                                    // batch stride c (unused)
      m * n,                                    // batch stride d (unused)
      lda,                                      // stride a
      ldb,                                      // stride b
      ldc,                                      // stride c (unused)
      ldc);                                     // stride d (unused)
  // Fill the EVT argument tree; D is actually written through the epilogue's
  // AuxStore node (ptr_d below), not the GemmUniversal d pointer above.
  if constexpr (WithBias) {
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},  // row-broadcast scales_b
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},  // col-broadcast scales_a
            {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},  // row-broadcast bias
            {}  // Multiplies
        },
        {ptr_d, {n, _1{}, _0{}}}};  // output store, leading dimension n
  } else {
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},  // row-broadcast scales_b
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},  // col-broadcast scales_a
            {}  // Multiplies
        },
        {ptr_d, {n, _1{}, _0{}}}};  // output store, leading dimension n
  }
  return args;
}
// Allocate workspace and run the SM89 FP8 rowwise GEMM on the current stream.
// Raises via TORCH_CHECK if CUTLASS cannot implement or run the problem.
template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto can_implement = gemm_op.can_implement(args);
  // Report the CUTLASS status string on failure (consistent with the SM90
  // launcher); the original checks raised with no diagnostic at all.
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
// Instantiate the SM89 FP8 rowwise GEMM with or without the fused bias
// epilogue, depending on whether `bias` is present, and launch it.
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  if (!bias.has_value()) {
    using GemmNoBias = typename DeviceGemmFp8RowwiseSm89<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CtaShape,
        WarpShape,
        Stages,
        false>::Gemm;
    return launch_sm89_fp8_scaled_mm<GemmNoBias, false>(out, a, b, scales_a, scales_b, bias);
  }
  using GemmWithBias = typename DeviceGemmFp8RowwiseSm89<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CtaShape,
      WarpShape,
      Stages,
      true>::Gemm;
  return launch_sm89_fp8_scaled_mm<GemmWithBias, true>(out, a, b, scales_a, scales_b, bias);
}
// Tile-shape heuristic for the SM89 FP8 rowwise GEMM: picks a CTA shape, warp
// shape, and pipeline stage count from the problem's M (rows of a) and
// N (columns of out). The thresholds and shapes below are a performance-tuned
// lookup table — do not "simplify" them without re-benchmarking.
template <typename OutType>
void sm89_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  uint32_t const n = out.size(1);
  if (m == 1) {
    // M == 1 (single-token / GEMV-like case)
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 16) {
    // M in (1, 16]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    // M in (16, 64]
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    // M in (64, 128]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 256) {
    // M in (128, 256]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 64, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 64, 128>,
          cutlass::gemm::GemmShape<64, 32, 128>,
          4>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 512) {
    // M in (256, 512)
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          2>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    // M in (512, inf)
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          3>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          2>(out, a, b, scales_a, scales_b, bias);
    }
  }
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <
typename ElementType,
typename OutElementType,
typename AccumElementType,
typename CTAShape,
typename ClusterShape,
typename MainloopScheduleType,
typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
// A matrix configuration
using ElementA = ElementType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = ElementType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = void; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<OutElementType>::value; // Memory access granularity/alignment of C matrices in
// units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = OutElementType; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
// // Auxiliary matrix configuration and other fusion types
// using ElementBias = float;
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = AccumElementType; // Element type for internal accumulation
using ElementCompute = float; // Element type for compute
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = CTAShape; // Threadblock-level tile size
static constexpr bool PONG = false;
static constexpr bool FAST_ACCUM = true;
static constexpr bool USE_BIAS = false;
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
// setting in the Collective Builder
// Implement rowwise scaling epilogue.
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementOutput,
ElementOutput,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementComputeEpilogue, // First stage output type.
ElementComputeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementOutput,
ElementComputeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
// With bias
using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add,
ElementOutput,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementC,
LayoutC,
AlignmentC,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecialized,
EpilogueEVT>::CollectiveOp;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using SlowAccum = DefaultSchedule;
using FastAccum = FastPongSchedule; // Default apply Pingpong
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
LayoutA,
AlignmentA,
ElementB,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Builds CUTLASS 3.x Gemm::Arguments for one SM90 FP8 rowwise-scaled GEMM:
//   out = scales_a[:, None] * scales_b[None, :] * (a @ b) (+ bias when WithBias)
//
// Template parameters:
//   Gemm     - a cutlass::gemm::device::GemmUniversalAdapter instantiation
//              produced by DeviceGemmFp8RowwiseSm90 above.
//   WithBias - selects the multiply_add epilogue; `bias` must then hold a value.
//
// Layout assumptions (validated by the caller fp8_scaled_mm):
//   a is (m, k) row-major, b is (k, n) column-major, out is (m, n) row-major;
//   scales_a has m float32 elements, scales_b has n float32 elements.
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value())
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  // Packed strides for a single (batch = 1) GEMM.
  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;  // C operand is unused (ElementC is void); stride stays default-initialized.
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {ptr_a, stride_a, ptr_b, stride_b},
      {{},  // epilogue.thread — filled in below
       nullptr,
       stride_c,
       ptr_d,
       stride_d}};
  // NOTE: the nested initializers are positional and must mirror the EVT tree
  // declared in DeviceGemmFp8RowwiseSm90:
  //   Sm90EVT<Compute1, XScale, EVTCompute0>                      (no bias)
  //   Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>         (with bias)
  // where EVTCompute0 = Sm90EVT<Compute0, WScale, Accum>.
  if constexpr (WithBias) {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {ptr_bias},
        {},  // Multiplies
    };
  } else {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {},  // Multiplies
    };
  }
  return args;
}
// Prepares arguments, allocates workspace, and launches one SM90 FP8
// rowwise-scaled GEMM on the current CUDA stream of `a`'s device.
//
// Raises (via TORCH_CHECK) when CUTLASS rejects the problem shape/config or
// when the kernel launch fails; messages include the CUTLASS status string
// for parity with the int8 path.
template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  // Validate the problem before paying for the workspace allocation.
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
// Instantiates the SM90 FP8 GEMM for the bias / no-bias epilogue variant and
// forwards to the launcher. `fast_accum` and `use_persistent` are currently
// accepted but not consulted by this dispatch.
template <
    typename OutType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename TileSchedulerType>
void sm90_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias,
    bool fast_accum = true,
    bool use_persistent = false) {
  using FP8Input = cutlass::float_e4m3_t;
  using Accum = float;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

  if (bias.has_value()) {
    using GemmWithBias = typename DeviceGemmFp8RowwiseSm90<
        FP8Input,
        OutType,
        Accum,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueSchedule,
        TileSchedulerType,
        /*WithBias=*/true>::Gemm;
    launch_sm90_fp8_scaled_mm<GemmWithBias, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using GemmNoBias = typename DeviceGemmFp8RowwiseSm90<
        FP8Input,
        OutType,
        Accum,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueSchedule,
        TileSchedulerType,
        /*WithBias=*/false>::Gemm;
    launch_sm90_fp8_scaled_mm<GemmNoBias, false>(out, a, b, scales_a, scales_b, bias);
  }
}
// Chooses a (CTA tile, cluster shape, mainloop scheduler, tile scheduler)
// configuration for the SM90 FP8 GEMM from the row count m, then dispatches on
// bias presence. The thresholds are performance heuristics; every branch
// computes the same result.
template <typename OutType>
void sm90_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
  using BasicTileScheduler = void;  // default (non-persistent) tile scheduler
  if (m <= 1) {
    // Degenerate GEMV-like case: basic (non-pingpong) schedule, wide cluster.
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _8, _1>,
        FastBasicScheduler,
        BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
  if (m <= 64) {
    // m in [1, 64]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _4, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    // m in (64, 256]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 1024) {
    // m in (256, 1024]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_128, _128, _128>,
        Shape<_1, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else {
    // m in (1024, inf)
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_128, _128, _128>,
        Shape<_2, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
}
#endif
// FP8 (e4m3) rowwise-scaled matmul entry point:
//   out = scales_a[:, None] * scales_b[None, :] * (mat_a @ mat_b) (+ bias)
//
// Expects: mat_a (m, k) row-major fp8e4m3, mat_b (k, n) column-major fp8e4m3,
// scales_a (m) float32, scales_b (n) float32, optional bias (n) of out_dtype.
// Returns a newly allocated (m, n) tensor of out_dtype (Half or BFloat16).
// Dispatches to SM90 (CUDA >= 12.0) or SM89 (CUDA >= 12.4) kernels; raises
// for other compute capabilities.
torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  // Fixed message: this check is about mat_b (previously said "mat_a").
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    if (out_dtype == torch::kBFloat16) {
      sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}

View File

@@ -0,0 +1,599 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/numeric_types.h>
#include <cute/atom/mma_atom.hpp>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
#include "utils.h"
using namespace cute;
// CUTLASS 2.x int8 GEMM with a per-row/per-col scaling epilogue (SM75/SM80):
//   out = scales_a[:, None] * scales_b[None, :] * (mat_a @ mat_b) (+ bias)
//
// mat_a is (m, k) row-major int8, mat_b is (k, n) column-major int8, out is
// (m, n) row-major ElementOutput. Accumulation is int32, scaling is float.
// Raises (via TORCH_CHECK) with the CUTLASS status string when the config is
// unsupported or the launch fails.
template <
    typename ElementOutput,
    typename ArchTag,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    int NumStages>
void cutlass_int8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
  using DefaultGemmConf = cutlass::gemm::device::
      DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  // Base kernel; its epilogue is replaced below with the visitor-based one.
  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA,
      cutlass::layout::RowMajor,
      DefaultGemmConf::kAlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      DefaultGemmConf::kAlignmentB,
      ElementOutput,
      cutlass::layout::RowMajor,
      ElementAccumulator,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOutputOp,
      ThreadblockSwizzle,
      NumStages,
      true,
      typename DefaultGemmConf::Operator>::GemmKernel;

  // Iterator for streaming the per-row (alpha column) scale vector.
  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
          cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape,
      GemmKernel_::kThreadCount,
      AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator,
      ElementAccumulator,
      ElementCompute,
      EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::
      EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;

  using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  int64_t lda = mat_a.stride(0);
  int64_t ldb = mat_b.stride(1);
  int64_t ldd = out.stride(0);

  // Bias is a broadcast row vector, hence leading dimension 0.
  ElementOutput* bias_ptr = nullptr;
  int64_t ldc = 0;
  if (bias) {
    bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
  }

  typename EpilogueOutputOp::Params linearScalingParams;  // default alpha/beta
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  // Argument order follows GemmWithEpilogueVisitor: problem size, A, B,
  // per-col scales, per-row scales, bias (C), output (D), visitor args.
  typename Gemm::Arguments args{
      {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};

  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  // Fixed typo in error message ("executioin" -> "execution").
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
// SM75 (Turing) int8 GEMM shape dispatch: picks threadblock/warp tile shapes
// from the row count m. All configurations use 2 pipeline stages and produce
// identical results; only performance differs.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  constexpr int kStages = 2;
  const int rows = mat_a.size(0);
  if (rows <= 32) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<32, 128, 64>,
        cutlass::gemm::GemmShape<32, 64, 64>,
        InstructionShape,
        kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }
  if (rows <= 64) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }
  if (rows <= 256) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }
  cutlass_int8_scaled_mm<
      ElementOutput,
      ArchTag,
      cutlass::gemm::GemmShape<128, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      InstructionShape,
      kStages>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
// SM80 (Ampere) int8 GEMM shape dispatch: picks threadblock/warp tile shapes
// and pipeline stage counts from (m, n). Every branch computes the same
// result; the thresholds are performance heuristics.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 16) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      // Same tiles as above but one fewer pipeline stage for large n.
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 32) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128 && n < 8192) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    // Large-problem fallback.
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
// CUTLASS 3.x int8 GEMM for SM90 (Hopper) with a rowwise-scaling EVT epilogue:
//   out = scales_a[:, None] * scales_b[None, :] * (mat_a @ mat_b) (+ bias)
//
// mat_a is (m, k) row-major int8, mat_b is (k, n) column-major int8, out is
// (m, n) row-major ElementOutput. Accumulates in int32, scales in float.
// When WithBias is true, `bias` must hold an n-element row vector of
// ElementOutput that is fused via multiply_add.
template <
    typename ElementOutput,
    typename TileShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    bool WithBias>
void cutlass_int8_scaled_mm_sm90(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ArchTag = cutlass::arch::Sm90;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
  using TileSchedulerType = cutlass::gemm::PersistentScheduler;

  // Epilogue visitor tree: per-row scale (column broadcast), per-col scale
  // (row broadcast), optional bias (row broadcast).
  using XScale = cutlass::epilogue::fusion::
      Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;
  using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;
  using Bias = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
  // Scale
  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
  using Compute1 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
  // With bias
  using ComputeWithBias = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentC,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentOutput,
      EpilogueScheduleType,
      EpilogueEVT>::CollectiveOp;

  // Carve the epilogue's shared-memory footprint out of the mainloop stages.
  using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
      sizeof(typename CollectiveEpilogue::SharedStorage))>;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementInputA,
      cutlass::layout::RowMajor,
      AlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      Stages,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;  // C operand unused; stride left default-initialized
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {a_ptr, stride_a, b_ptr, stride_b},
      {{},  // epilogue.thread — filled in below
       nullptr,
       stride_c,
       o_ptr,
       stride_d}};
  if constexpr (WithBias) {
    // Guard against a missing bias before dereferencing (mirrors the fp8 path).
    TORCH_CHECK(bias.has_value());
    ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
    args.epilogue.thread = {
        {a_s_ptr},
        {{b_s_ptr}, {}, {}},
        {bias_ptr},
        {},
    };
  } else {
    args.epilogue.thread = {
        {a_s_ptr},
        {{b_s_ptr}, {}, {}},
        {},
    };
  }

  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  // Fixed typo in error message ("executioin" -> "execution").
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
// Forwards to the SM90 int8 kernel instantiated for the bias / no-bias
// epilogue, depending on whether `bias` holds a value.
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  if (bias.has_value()) {
    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, /*WithBias=*/true>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }
  cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, /*WithBias=*/false>(
      out, mat_a, mat_b, scales_a, scales_b, bias);
}
// SM90 (Hopper) int8 GEMM shape dispatch: selects CTA tile, cluster shape, and
// mainloop schedule from (m, n) heuristics, then dispatches on bias presence.
// Every branch computes the same result; only performance differs.
template <typename ElementOutput>
void sm90_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 32) {
    if (n < 8192) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_1, _8, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _128, _128>,
          Shape<_1, _8, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n < 8192) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_1, _4, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _256>,
          Shape<_1, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    if (n <= 4096) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_2, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _128, _128>,
          Shape<_2, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else {
    // Large m: bigger tiles with the pingpong schedule.
    return sm90_dispatch_bias<
        ElementOutput,
        Shape<_128, _128, _128>,
        Shape<_2, _1, _1>,
        cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
// Int8 rowwise-scaled matmul entry point:
//   out = scales_a[:, None] * scales_b[None, :] * (mat_a @ mat_b) (+ bias)
//
// Expects: mat_a (m, k) row-major int8, mat_b (k, n) column-major int8,
// scales_a (m) float32, scales_b (n) float32, optional bias (n) of out_dtype.
// Returns a newly allocated (m, n) tensor of out_dtype (Half or BFloat16;
// SM75 supports Half only). Dispatches by compute capability: SM75, SM80,
// or SM90 (falling back to the SM80 CUTLASS 2.x path when CUDA < 12.0).
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  // Fixed message: this check is about mat_b (previously said "mat_a").
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment");  // out.stride(0)
  TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
  TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");
  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));

  auto sm_version = getSMVersion();

  if (sm_version >= 75 && sm_version < 80) {
    TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
    sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (sm_version >= 80 && sm_version < 90) {
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    // cutlass 3.x
    if (out_dtype == torch::kBFloat16) {
      sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
#else
    // fallback to cutlass 2.x
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
#endif
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
  }

  return out;
}

View File

@@ -0,0 +1,125 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
// Grid-stride reduction kernel: computes max(|input[i]|) over all elements
// and atomically folds (block_max / FP8_E4M3_MAX) into *output_s, producing a
// per-tensor FP8 scale. Bulk of the data is read with 16-byte vectorized
// loads; a scalar tail loop covers the remainder.
// blockReduceMax / atomicMaxFloat / FP8_E4M3_MAX are project helpers from
// utils.h (not visible here).
template <typename T>
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;

  // vec_size elements of T per 16-byte vector load.
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;
  const int32_t num_vec_elems = num_elements / vec_size;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  // Scalar tail for elements not covered by a full vector.
  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = static_cast<float>(input[idx]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  max_value = blockReduceMax(max_value);
  if (tid == 0) {
    // One atomic per block; maps the absmax into the FP8 E4M3 scale domain.
    atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
  }
}
template <typename T>
// Quantizes `input` to FP8 with a single precomputed per-tensor scale.
// `scale` holds absmax / FP8_E4M3_MAX (see per_tensor_absmax_kernel), so each
// element is multiplied by its reciprocal and clamped to [-FP8_E4M3_MAX,
// FP8_E4M3_MAX]. Grid-stride over the flattened tensor; vectorized 16-byte
// loads with a scalar tail for the remainder.
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  // Multiply by the reciprocal instead of dividing per element.
  const float scale_val = 1.0f / (*scale);

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;
  const int32_t num_vec_elems = num_elements / vec_size;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);

    FP8_TYPE output_arr[vec_size];
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      // BUGFIX: was `value`, which is not declared — broke ROCm compilation.
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      output[i * vec_size + j] = output_arr[j];
    }
  }

  // Scalar tail for elements that do not fill a full vector.
  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    output[idx] = static_cast<FP8_TYPE>(val);
#else
    // BUGFIX: was `value`, which is not declared — broke ROCm compilation.
    output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}
// Per-tensor FP8 quantization entry point.
//   input    : contiguous CUDA tensor to quantize (fp32/fp16/bf16 per dispatch).
//   output_q : FP8 output, same element count as input.
//   output_s : single-float scale. When is_static == false it is computed here
//              (absmax / FP8_E4M3_MAX); when true the caller-provided value is
//              used as-is.
void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int block_size = 256;
  // BUGFIX: numel() is int64_t; storing it in `int` truncated tensors with
  // >= 2^31 elements, while the kernels themselves already take int64_t.
  const int64_t num_elements = input.numel();
  int64_t num_blocks = (num_elements + block_size - 1) / block_size;
  if (num_blocks > 1024) num_blocks = 1024;  // cap grid; kernels grid-stride

  dim3 grid(static_cast<unsigned int>(num_blocks));
  dim3 block(block_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (is_static == false) {
      // Dynamic quantization: derive the scale from the data first.
      per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
    }

    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}

View File

@@ -0,0 +1,105 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include "utils.h"
using FP8_TYPE = c10::Float8_e4m3fn;
// Tree-reduces the max of 16 floats in smem[0..15] down to smem[0].
// Intended to be called by the first 8 threads of a 16-lane group (the call
// site guards with `local_tid < 8`); the unguarded first step pairs smem[tid]
// with smem[tid + 8] for tid in [0, 8).
// NOTE(review): relies on `volatile` plus implicit warp-synchronous execution
// of the participating lanes; under Volta+ independent thread scheduling this
// classic pattern is not guaranteed without __syncwarp()/shuffles — confirm
// the lanes stay converged on the target architectures.
__device__ __forceinline__ float GroupReduceMax(volatile float* smem, const int tid) {
  smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
  if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
  if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
  if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
  return smem[0];
}
template <typename T>
// Quantizes each contiguous run of `group_size` elements of `input` with its
// own scale: scale = group_absmax / fp8_max, written to output_s[group], and
// each value clamped to [fp8_min, fp8_max] after dividing by the scale.
// Layout: 256 threads per block = 16 groups x 16 lanes; block b handles
// groups [16*b, 16*b + 16).
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int block_group_id = blockIdx.x * groups_per_block;
  const int tid = threadIdx.x;
  const int local_group_id = tid / 16;  // which of the block's 16 groups
  const int local_tid = tid % 16;       // lane within that group

  // 16 partial maxima per group; the +1 column pads each row, which avoids
  // shared-memory bank conflicts for the column-wise reduction accesses.
  __shared__ float s_absmax[16][17];

  // Seed with eps so the derived scale is never zero.
  float local_absmax = eps;

  if (block_group_id + local_group_id < num_groups) {
    const T* group_input = input + (block_group_id + local_group_id) * group_size;
    FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + (block_group_id + local_group_id) * group_size;
    float* scale_output = output_s + block_group_id + local_group_id;

    // Phase 1: each of the 16 lanes accumulates a strided partial abs-max.
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float abs_val = fabsf(val);
      local_absmax = fmaxf(local_absmax, abs_val);
    }

    s_absmax[local_group_id][local_tid] = local_absmax;
    __syncthreads();

    // Phase 2: first 8 lanes of the group tree-reduce the 16 partials.
    if (local_tid < 8) {
      GroupReduceMax(&s_absmax[local_group_id][0], local_tid);
    }
    __syncthreads();

    const float group_absmax = s_absmax[local_group_id][0];
    const float y_s = group_absmax / fp8_max;
    if (local_tid == 0) {
      *scale_output = y_s;
    }

    // Phase 3: all 16 lanes quantize the group with the shared scale.
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
      group_output[i] = FP8_TYPE(q_val);
    }
  }
}
// Per-token-group FP8 quantization entry point.
// Splits `input` into contiguous groups of `group_size` elements, writes the
// quantized values to `output_q` and one float scale per group to `output_s`.
// Requires input.numel() to be an exact multiple of group_size.
void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int num_groups = input.numel() / group_size;
  CHECK_EQ(input.numel() % group_size, 0);

  // The kernel packs 16 groups into each 256-thread block (16 lanes/group).
  const int groups_per_block = 16;
  dim3 grid_dim((num_groups + groups_per_block - 1) / groups_per_block);
  dim3 block_dim(256);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid_dim, block_dim, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        static_cast<float>(eps),
        static_cast<float>(fp8_min),
        static_cast<float>(fp8_max));
    return true;
  });
}

View File

@@ -0,0 +1,111 @@
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
template <typename T>
// Quantizes each row (token) of a [num_tokens, hidden_dim] matrix to FP8 with
// a per-row scale: scale = row_absmax / FP8_E4M3_MAX, stored in
// output_s[token]; values are divided by the scale and clamped to
// [-FP8_E4M3_MAX, FP8_E4M3_MAX]. One block per token.
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;

  const int tid = threadIdx.x;
  const int block_dim = blockDim.x;

  const T* token_input = input + token_idx * hidden_dim;
  FP8_TYPE* token_output = output_q + token_idx * hidden_dim;

  // Pass 1: block-wide abs-max of this row (scalar loads).
  float max_value = 0.0f;
  for (int i = tid; i < hidden_dim; i += block_dim) {
    float val = static_cast<float>(token_input[i]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  max_value = blockReduceMax(max_value);

  // Thread 0 publishes the scale; the barrier below makes it visible to all.
  __shared__ float block_max;
  if (tid == 0) {
    block_max = max_value / FP8_E4M3_MAX;
    output_s[token_idx] = block_max;
  }
  __syncthreads();

  // NOTE(review): an all-zero row gives block_max == 0, so scale_val is inf
  // and 0 * inf is NaN before the fmin/fmax clamp — confirm this degenerate
  // case is acceptable to callers.
  const float scale_val = 1.0f / block_max;

  // Pass 2: quantize with vectorized 16-byte loads.
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;
  const int32_t num_vec_elems = hidden_dim / vec_size;

  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * vec_size);

    FP8_TYPE output_arr[vec_size];
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      token_output[i * vec_size + j] = output_arr[j];
    }
  }

  // Scalar tail for elements past the last full vector.
  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    token_output[idx] = static_cast<FP8_TYPE>(val);
#else
    token_output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}
// Per-token (per-row) FP8 quantization entry point.
// `input` is a [num_tokens, hidden_dim] CUDA tensor; each row is quantized
// with its own scale into `output_q`, with the scales written to `output_s`.
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const auto shape = input.sizes();
  const int64_t num_tokens = shape[0];
  const int64_t hidden_dim = shape[1];

  // One block of 128 threads per token row.
  const int threads_per_block = 128;
  dim3 grid_dim(num_tokens);
  dim3 block_dim(threads_per_block);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid_dim, block_dim, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
    return true;
  });
}