Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
172
sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
Normal file
172
sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
Normal file
@@ -0,0 +1,172 @@
|
||||
// References:
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
|
||||
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
|
||||
// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Verify that inputs, weights and outputs all contain the same number of matrices.
static void check_group_count(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs) {
  const bool counts_match = (inputs.size() == weights.size()) && (inputs.size() == outputs.size());
  TORCH_CHECK(counts_match, "The group count of inputs, weights and outputs should be the same.");
}
|
||||
|
||||
// Ensure every tensor in the list lives on a CUDA device and carries the expected dtype.
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(tensor.dtype() == dtype, "dtype of all the tensors should be the same");
    TORCH_CHECK(tensor.is_cuda(), "All tensors should be in Cuda memory");
  }
}
|
||||
|
||||
// Collect size(dim) of each tensor. All tensors must be 2D matrices.
static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  results.reserve(tensors.size());
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
    results.push_back(t.size(dim));
  }
  // Plain return enables NRVO; `return std::move(results)` would pessimize by
  // suppressing copy elision (flagged by -Wpessimizing-move).
  return results;
}
|
||||
|
||||
// Collect stride(dim) of each tensor (used as cuBLAS leading dimensions).
static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  results.reserve(tensors.size());
  for (const auto& t : tensors) {
    results.push_back(t.stride(dim));
  }
  // Plain return enables NRVO; `return std::move(results)` would pessimize by
  // suppressing copy elision (flagged by -Wpessimizing-move).
  return results;
}
|
||||
|
||||
// Element-wise equality check between two dimension/stride arrays; raises err_msg
// on the first mismatch (or on a length mismatch).
static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
  // Guard the lengths first so the element loop below cannot read out of bounds.
  TORCH_CHECK(a.size() == b.size(), err_msg);
  // size_t index avoids the signed/unsigned comparison with std::vector::size().
  for (size_t i = 0; i < a.size(); ++i) {
    TORCH_CHECK(a[i] == b[i], err_msg);
  }
}
|
||||
|
||||
// Gather the raw device data pointers of all tensors in the list.
static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
  std::vector<void*> ptrs;
  ptrs.reserve(tensors.size());
  for (const auto& t : tensors) {
    ptrs.push_back(t.data_ptr());
  }
  // Plain return enables NRVO; `return std::move(ptrs)` would pessimize by
  // suppressing copy elision (flagged by -Wpessimizing-move).
  return ptrs;
}
|
||||
|
||||
// Upload a host array of device pointers to GPU memory, returning the owning tensor.
// A kDouble tensor is used purely as an 8-byte-per-element buffer, relying on
// sizeof(double) == sizeof(void*) on 64-bit platforms.
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
  torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
  // cudaMemcpyAsync returns cudaError_t, so compare against cudaSuccess —
  // the previous comparison against CUBLAS_STATUS_SUCCESS mixed enum families
  // and only worked because both success codes happen to be 0.
  TORCH_CHECK(
      cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
      cudaSuccess);
  return gpu_ptrs;
}
|
||||
|
||||
// We want compute input @ weight^T in row major
|
||||
// This is equivalent to computing weight @ input^T in col major
|
||||
// Cublas only accepts matrix in column major, so this arrangement is needed
|
||||
// Grouped GEMM over lists of independent (input, weight, output) triples using
// cublasGemmGroupedBatchedEx (requires CUDA >= 12.5 at compile time).
//
// We want to compute input @ weight^T in row major.
// This is equivalent to computing weight @ input^T in col major.
// Cublas only accepts matrices in column major, so this arrangement is needed.
//
// Preconditions (checked below): all tensors are CUDA, 2D, and share out_dtype
// (fp16 or bf16). cublas_handle / cuda_stream are raw handles passed as ints.
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can "  // trailing space: adjacent literals concatenate
      "only be applied to float16 and bfloat16 dtype");

  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);
  // One problem per group (batch count 1 each).
  std::vector<int> group_size(group_count, 1);

  // Make sure all tensors are on cuda and use the same dtype.
  check_device_dtype(out_dtype, inputs);
  check_device_dtype(out_dtype, weights);
  check_device_dtype(out_dtype, outputs);
  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);

  // Weights should be transposed to (n, k) of column major.
  std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
  std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);

  // Per-group problem sizes, named after the col-major computation
  // C(n, m) = W^T(n, k) @ X(k, m).
  std::vector<int> m_array = get_dims(weights, 0);
  std::vector<int> n_array = get_dims(inputs, 0);
  std::vector<int> k_array = get_dims(inputs, 1);

  // Make sure the dimensions in each group match.
  std::vector<int> m_array1 = get_dims(outputs, 1);
  std::vector<int> n_array1 = get_dims(outputs, 0);
  std::vector<int> k_array1 = get_dims(weights, 1);
  check_equal(m_array, m_array1, "sizes don't match on m dimension");
  check_equal(n_array, n_array1, "sizes don't match on n dimension");
  check_equal(k_array, k_array1, "sizes don't match on k dimension");

  // Leading dimensions: row stride of each (row-major) tensor.
  std::vector<int> lda_array = get_strides(weights, 0);
  std::vector<int> ldb_array = get_strides(inputs, 0);
  std::vector<int> ldc_array = get_strides(outputs, 0);

  // Use default scaling factors: C = 1.0 * A @ B + 0.0 * C.
  std::vector<float> alpha_array(group_count, 1);
  std::vector<float> beta_array(group_count, 0);

  std::vector<void*> a_array = get_tensor_ptrs(weights);
  std::vector<void*> b_array = get_tensor_ptrs(inputs);
  std::vector<void*> c_array = get_tensor_ptrs(outputs);

  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  // cublasGemmGroupedBatchedEx consumes device-resident pointer arrays, so the
  // host arrays above must be staged into GPU tensors first.
  torch::Tensor d_a = create_ptr_pointer(a_array, stream);
  torch::Tensor d_b = create_ptr_pointer(b_array, stream);
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);

#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  // Synchronize before returning so the host-side parameter arrays (alpha/beta,
  // dims, ld arrays) outlive the asynchronous launch that reads them.
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
  return;
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
|
||||
226
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Normal file
226
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Normal file
@@ -0,0 +1,226 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
|
||||
#include <cutlass/gemm/thread/mma.h>
|
||||
#include <cutlass/layout/matrix.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/epilogue/collective/default_epilogue.hpp>
|
||||
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Launch a CUTLASS SM90 (Hopper) FP8 blockwise-scaled GEMM writing into `out`.
// A is row-major e4m3, B is column-major e4m3; the float block scales are
// consumed inside the mainloop by the block-scaled kernel schedule below.
//
// Template parameters:
//   OutType           - output element type (e.g. cutlass::half_t / cutlass::bfloat16_t)
//   TileShape         - CTA tile shape (M, N, K)
//   ClusterShape      - thread-block cluster shape
//   ScaleGranularityM - sub-group size along M for scale accumulation (default 1)
// requires SM90; caller is expected to gate on compute capability / CUDA version.
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  // Accumulate and apply scales in fp32.
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // ElementC = void: no source C tensor, the epilogue stores the accumulator only.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  // Epilogue fusion tree: just fetch the accumulator (scaling already done in mainloop).
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;

  // Block-scaled FP8 mainloop schedule (from cutlass_extensions) with per-sub-group-M scale accumulation.
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      // Pick stage count automatically after carving out smem for the epilogue.
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  // Problem size from the torch tensors: out (m, n) = a (m, k) @ b (k, n).
  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  // Packed strides for batch count 1.
  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c;  // unused: ElementC is void
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  // NOTE(review): the literal 4 is presumably the mma promotion interval of the
  // block-scaled mainloop's Arguments — confirm against the collective's definition.
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
  // Epilogue: no fusion args, no C source (nullptr), write D = accumulator to o_ptr.
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},  // (M, N, K, batch)
      mainloop_args,
      epilogue_args,
  };

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  // Reject unsupported problem shapes/alignments before launching.
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
|
||||
|
||||
// Choose the tile/cluster configuration for the SM90 blockwise FP8 GEMM and
// forward to the launcher. Currently a single static configuration is used.
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using CtaTile = Shape<_128, _128, _128>;
  using Cluster = Shape<_1, _1, _1>;
  launch_sm90_fp8_blockwise_scaled_mm<OutType, CtaTile, Cluster>(out, a, b, scales_a, scales_b);
}
|
||||
|
||||
// Blockwise-scaled FP8 matmul: returns (mat_a * scales_a) @ (mat_b * scales_b)
// as a new tensor of dtype out_dtype.
//
// mat_a: (m, k) row-major Float8_e4m3fn; scales_a: (m, k/128) fp32, M-major.
// mat_b: (k, n) column-major Float8_e4m3fn; scales_b: (k/128, n/128) fp32, K-major.
// Dispatches to the CUTLASS SM90 kernel; raises NotImplementedError elsewhere.
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  // Fixed copy-paste bug: this message previously said "mat_a".
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  // 16-byte alignment is required by the CUTLASS TMA-based kernel.
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  // A tensor counts as a contiguous vector if it is 1D, or a 2D tensor with one
  // degenerate dimension (e.g. (m, 1) or (1, n)) stored contiguously.
  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  // Scale shapes assume a 128-wide quantization block on K (and N for mat_b).
  TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched");
  TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
  TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched");
  TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
    } else {
      sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b);
    }
    return out;
  }
#endif
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
|
||||
859
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Normal file
859
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Normal file
@@ -0,0 +1,859 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Adapted from
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
|
||||
#include <cutlass/gemm/thread/mma.h>
|
||||
#include <cutlass/layout/matrix.h>
|
||||
#include <cutlass/matrix_coord.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/epilogue/collective/default_epilogue.hpp>
|
||||
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
|
||||
// SM89 (Ada) FP8 rowwise-scaled GEMM type factory.
// Assembles a CUTLASS 2.x GemmUniversalAdapter whose EVT epilogue computes
//   D = acc * scale_b (row broadcast) * scale_a (col broadcast) [+ bias (row broadcast)]
// The `Gemm` member type is the ready-to-instantiate device-level GEMM.
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CtaShape,
    typename WarpShape,
    int Stages,
    bool WithBias,
    typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
    template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
    typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  // A: row-major FP8, 128-bit aligned accesses.
  using ElementA = ElementType;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B: column-major FP8.
  using ElementB = ElementType;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = OutElementType;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using ElementOutput = OutElementType;
  using LayoutOutput = cutlass::layout::RowMajor;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using ElementAccumulator = AccumElementType;
  // Scales (and EVT compute) are fp32.
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm89;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
  // Number of epilogue stages in EVT
  static constexpr int EVTEpilogueStages = 1;

  using OutputTileThreadMap = cutlass::epilogue::threadblock::
      OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;

  // Definition of EVT (epilogue visitor tree), built bottom-up:
  // accumulator fetch node.
  using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;

  // acc * scale_b, where scale_b is broadcast along rows (one value per column).
  using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using bScaleSrc = cutlass::epilogue::threadblock::
      VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;
  using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;

  // (acc * scale_b) * scale_a, where scale_a is broadcast along columns (one value per row).
  using ComputeAScale = cutlass::epilogue::threadblock::
      VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  using aScaleSrc = cutlass::epilogue::threadblock::
      VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;
  using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;

  // With bias: fused multiply-add replaces the final multiply, adding a
  // row-broadcast bias vector in the same pass.
  using biasSrc =
      cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
  using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add,
      ElementC,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EpilogueAScaleWithBias =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;

  // Final store node writing D to global memory.
  using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap,
      ElementC,
      cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, _1, _0>>;
  // Select the bias / no-bias tree at compile time via WithBias.
  using EpilogueStore = typename cutlass::platform::conditional<
      WithBias,
      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
      cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;

  using EpilogueOp = EpilogueStore;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementA,
      LayoutA,
      cutlass::ComplexTransform::kNone,
      AlignmentA,
      ElementB,
      LayoutB,
      cutlass::ComplexTransform::kNone,
      AlignmentB,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementAccumulator,
      ElementComputeEpilogue,
      OperatorClass,
      ArchTag,
      CtaShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      ThreadblockSwizzle,
      Stages,
      FP8MathOperator,
      EVTEpilogueStages>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
|
||||
|
||||
// Build the cutlass Arguments for the SM89 rowwise FP8 GEMM defined by
// DeviceGemmFp8RowwiseSm89. The nested brace initializers for `args.epilogue`
// mirror the EVT tree structure exactly (innermost node first), so their shape
// must match the WithBias template flag.
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;

  // Problem size: out (m, n) = a (m, k) @ b (k, n).
  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  // Leading dimensions: a is row-major, b column-major, out row-major.
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    // The bias optional must be populated when the WithBias kernel was selected.
    TORCH_CHECK(bias.has_value())
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  // C/D pointers and their batch strides are unused here: the EVT store node
  // (ptr_d inside args.epilogue) performs the actual output write.
  typename Gemm::Arguments args(
      cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
      {m, n, k},                                // Problem size
      1,                                        // Split-k factor
      {},                                       // Epilogue args
      ptr_a,                                    // a pointer
      ptr_b,                                    // b pointer
      nullptr,                                  // c pointer (unused)
      nullptr,                                  // d pointer (unused)
      m * k,                                    // batch stride a (unused)
      n * k,                                    // batch stride b (unused)
      m * n,                                    // batch stride c (unused)
      m * n,                                    // batch stride d (unused)
      lda,                                      // stride a
      ldb,                                      // stride b
      ldc,                                      // stride c (unused)
      ldc);                                     // stride d (unused)
  if constexpr (WithBias) {
    // EVT tree: store( fma( mul(acc, scale_b_row), scale_a_col, bias_row ) ).
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
            {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
            {}  // Multiplies
        },
        {ptr_d, {n, _1{}, _0{}}}};
  } else {
    // EVT tree: store( mul( mul(acc, scale_b_row), scale_a_col ) ).
    args.epilogue = {
        {
            {
                {},  // Accumulator
                {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                {}  // Multiplies
            },
            {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
            {}  // Multiplies
        },
        {ptr_d, {n, _1{}, _0{}}}};
  }

  return args;
}
|
||||
|
||||
// Run the SM89 rowwise FP8 GEMM: prepare arguments, allocate workspace on the
// same device as `a`, and launch on the current CUDA stream.
template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  // Include the cutlass status string so failures are diagnosable
  // (consistent with the other gemm launchers in this kernel set).
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
|
||||
|
||||
// Instantiate the SM89 rowwise FP8 GEMM with or without the fused bias epilogue
// (the WithBias flag is a compile-time template parameter, so both variants
// must be spelled out) and launch the selected one.
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;

  if (!bias.has_value()) {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CtaShape,
        WarpShape,
        Stages,
        false>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }

  using Gemm = typename DeviceGemmFp8RowwiseSm89<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CtaShape,
      WarpShape,
      Stages,
      true>::Gemm;
  return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
}
|
||||
|
||||
// Shape dispatcher for the SM89 (Ada) FP8 row-wise-scaled GEMM.
// Picks a tuned CUTLASS threadblock shape, warp shape, and pipeline stage
// count from the problem size and forwards to sm89_fp8_dispatch_bias.
// The (m, n) thresholds below form a tuning table; the exact GemmShape /
// stage values are the substance of this function and must not be altered.
//   m = rows of `a`, n = columns of `out`.
template <typename OutType>
void sm89_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  uint32_t const n = out.size(1);

  if (m == 1) {
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 16) {
    // M in (1, 16]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    // M in (16, 64]
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    // M in (64, 128]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 256) {
    // M in (128, 256]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 64, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          5>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<64, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 64, 128>,
          cutlass::gemm::GemmShape<64, 32, 128>,
          4>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 512) {
    // M in (256, 512)
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          2>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          4>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    // M in (512, inf)
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          3>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<
          OutType,
          cutlass::gemm::GemmShape<128, 128, 64>,
          cutlass::gemm::GemmShape<64, 32, 64>,
          2>(out, a, b, scales_a, scales_b, bias);
    }
  }
}
|
||||
#endif
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
// Type factory for an SM90 (Hopper) FP8 GEMM with row-wise dequantization
// fused into the epilogue via a CUTLASS Epilogue Visitor Tree (EVT):
//   D = XScale * (WScale * Accum) [+ Bias]   (WithBias selects the variant)
// All members are type aliases; the assembled kernel is exposed as `Gemm`.
// requires SM90+ and CUDA >= 12.0 (guarded by the surrounding #if).
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  // A matrix configuration
  using ElementA = ElementType;               // Element type for A matrix operand
  using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A
                                                    // matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = ElementType;                  // Element type for B matrix operand
  using LayoutB = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B
                                                    // matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration (ElementC == void: no source C operand is read)
  using ElementC = void;                      // Element type for C matrix operands
  using LayoutC = cutlass::layout::RowMajor;  // Layout type for C matrix operands
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<OutElementType>::value;  // Memory access granularity/alignment of C matrices in
                                                          // units of elements (up to 16 bytes)

  // Output matrix configuration
  using ElementOutput = OutElementType;            // Element type for output matrix operands
  using LayoutOutput = cutlass::layout::RowMajor;  // Layout type for output matrix operands
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // // Auxiliary matrix configuration and other fusion types
  // using ElementBias = float;

  // Multiply-accumulate blocking/pipelining details
  using ElementAccumulator = AccumElementType;  // Element type for internal accumulation
  using ElementCompute = float;                 // Element type for compute
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
  using TileShape = CTAShape;                            // Threadblock-level tile size

  // NOTE(review): these three flags are not referenced below; the effective
  // schedule comes from MainloopScheduleType / EpilogueScheduleType.
  static constexpr bool PONG = false;
  static constexpr bool FAST_ACCUM = true;
  static constexpr bool USE_BIAS = false;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;      // Stage count maximized
                                                                         // based on the tile size
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;  // Kernel to launch based on the default
                                                                         // setting in the Collective Builder
  // Implement rowwise scaling epilogue.
  // XScale: column-broadcast vector (stride <1,0,0>), i.e. one value per row.
  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  // WScale: row-broadcast vector (stride <0,1,0>), i.e. one value per column.
  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Bias: row-broadcast vector in the output element type.
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementOutput,
      ElementOutput,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementComputeEpilogue,  // First stage output type.
      ElementComputeEpilogue,  // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  // EVTCompute0 = WScale * Accum
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementOutput,
      ElementComputeEpilogue,  // Second stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  // EVTCompute1 = XScale * EVTCompute0  (the no-bias epilogue)
  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias: XScale * EVTCompute0 + Bias
  using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add,
      ElementOutput,
      ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementComputeEpilogue,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementOutput,
      LayoutOutput,
      AlignmentOutput,
      cutlass::epilogue::TmaWarpSpecialized,
      EpilogueEVT>::CollectiveOp;

  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

  using SlowAccum = DefaultSchedule;
  using FastAccum = FastPongSchedule;  // Default apply Pingpong

  // Mainloop stage count is auto-carved to leave shared memory for the
  // epilogue's SharedStorage.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
|
||||
|
||||
// Builds the CUTLASS Gemm::Arguments for the SM90 row-wise-scaled FP8 GEMM:
//   out = scales_a (per-row) * scales_b (per-column) * (a @ b) [+ bias]
// where a is row-major (m, k), b is column-major (k, n), out is row-major
// (m, n). The nested brace initializers for `epilogue.thread` must mirror the
// EVT trees (EVTComputeWithBias / EVTCompute1) declared in
// DeviceGemmFp8RowwiseSm90: outer node = final multiply(-add) with the
// per-row scale (XScale); inner node = per-column scale (WScale) * Accum.
template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    // Previously a bare TORCH_CHECK with no diagnostic message.
    TORCH_CHECK(bias.has_value(), "bias must be provided when WithBias is true");
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;  // C operand is unused (ElementC == void); left default-initialized
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {ptr_a, stride_a, ptr_b, stride_b},
      {{},  // epilogue.thread; filled in below per WithBias
       nullptr,
       stride_c,
       ptr_d,
       stride_d}};
  if constexpr (WithBias) {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {ptr_bias},
        {},  // Multiplies
    };
  } else {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {},  // Multiplies
    };
  }

  return args;
}
|
||||
|
||||
// Launches the given SM90 FP8 GEMM on the current CUDA stream of `a`'s
// device. Allocates the CUTLASS workspace as a temporary uint8 tensor.
// Raises via TORCH_CHECK — now with the CUTLASS status string, matching the
// error style used by the int8 kernels — if the configuration cannot be
// implemented for this problem or if the run fails.
template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  // The original checks carried no message; include the CUTLASS status so
  // failures are diagnosable.
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
|
||||
|
||||
// Instantiates DeviceGemmFp8RowwiseSm90 with the given tile/cluster shapes
// and schedules, selecting the with-bias or no-bias epilogue variant at
// runtime from whether `bias` is present, then launches it.
// NOTE(review): `fast_accum` and `use_persistent` are accepted but unused in
// this body — the schedule and tile scheduler are fixed by the template
// parameters. Kept for interface compatibility; confirm with callers before
// removing.
template <
    typename OutType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename TileSchedulerType>
void sm90_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias,
    bool fast_accum = true,
    bool use_persistent = false) {
  using ElementInput = cutlass::float_e4m3_t;  // FP8 e4m3 input elements
  using ElementOutput = OutType;
  using AccumElementType = float;
  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;

  if (bias) {
    // WithBias = true: epilogue fuses the (n,)-bias add.
    using Gemm = typename DeviceGemmFp8RowwiseSm90<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        true>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    // WithBias = false: scale-only epilogue.
    using Gemm = typename DeviceGemmFp8RowwiseSm90<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        false>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}
|
||||
|
||||
// Shape dispatcher for the SM90 FP8 GEMM, keyed on the row count m only.
// m <= 1 uses the basic (non-pingpong) FP8 fast-accum schedule with no tile
// scheduler; larger m uses the pingpong fast-accum schedule with a persistent
// tile scheduler. The CTA/cluster shapes per bucket are a tuning table.
template <typename OutType>
void sm90_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
  using BasicTileScheduler = void;
  if (m <= 1) {
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _8, _1>,
        FastBasicScheduler,
        BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
  if (m <= 64) {
    // m in [1, 64]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _4, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    // m in (64, 256]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_64, _64, _128>,
        Shape<_1, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 1024) {
    // m in (256, 1024]
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_128, _128, _128>,
        Shape<_1, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else {
    // m in (1024, inf)
    return sm90_fp8_dispatch_bias<
        OutType,
        Shape<_128, _128, _128>,
        Shape<_2, _1, _1>,
        FastPingpongScheduler,
        PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
}
|
||||
#endif
|
||||
|
||||
// Row-wise scaled FP8 GEMM entry point:
//   out = scales_a (per-row) * scales_b (per-column) * (mat_a @ mat_b) [+ bias]
// Preconditions (all enforced below):
//   mat_a:    (m, k) Float8_e4m3fn, row major, k 16-byte aligned
//   mat_b:    (k, n) Float8_e4m3fn, column major, k 16-byte aligned
//   scales_a: m elements, Float32, contiguous
//   scales_b: n elements, Float32, contiguous
//   bias:     optional, n elements, contiguous, dtype == out_dtype
// Returns a newly allocated (m, n) tensor of out_dtype (Half or BFloat16).
// Dispatches to the SM90 path (CUDA >= 12.0) or the SM89 path
// (CUDA >= 12.4); otherwise raises NotImplementedError.
torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  // Fixed message: this check is about mat_b, not mat_a.
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  // Fixed typo in message ("msut" -> "must").
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    if (out_dtype == torch::kBFloat16) {
      sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}
|
||||
// ==== diff boundary: new file sgl-kernel/csrc/gemm/int8_gemm_kernel.cu (599 lines, @@ -0,0 +1,599 @@) ====
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
|
||||
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
|
||||
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Int8 GEMM with per-row/per-column float dequantization fused into the
// epilogue via a pre-SM90 (threadblock-level) epilogue visitor:
//   out = scales_a (per-row) * scales_b (per-column) * (mat_a @ mat_b) [+ bias]
// mat_a is row-major int8 (m, k); mat_b is column-major int8 (k, n);
// accumulation is int32, scaling in float, output in ElementOutput.
// Runs on the current CUDA stream of mat_a's device.
// Fix: corrected the "executioin" typo in the failure message.
template <
    typename ElementOutput,
    typename ArchTag,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    int NumStages>
void cutlass_int8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

  using DefaultGemmConf = cutlass::gemm::device::
      DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
  using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;

  // Base kernel; its default epilogue is replaced by the visitor epilogue below.
  using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
      ElementInputA,
      cutlass::layout::RowMajor,
      DefaultGemmConf::kAlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      DefaultGemmConf::kAlignmentB,
      ElementOutput,
      cutlass::layout::RowMajor,
      ElementAccumulator,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOutputOp,
      ThreadblockSwizzle,
      NumStages,
      true,
      typename DefaultGemmConf::Operator>::GemmKernel;

  // Iterator that streams the per-row (alpha-column) float scales.
  using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
      cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
          typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
          GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
          GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
          cutlass::sizeof_bits<ElementOutput>::value>,
      ElementCompute>;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
      ThreadblockShape,
      GemmKernel_::kThreadCount,
      AlphaColTileIterator,
      typename GemmKernel_::Epilogue::OutputTileIterator,
      ElementAccumulator,
      ElementCompute,
      EpilogueOutputOp>;

  using Epilogue = typename cutlass::epilogue::threadblock::
      EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;

  using GemmKernel =
      cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;

  using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  int64_t lda = mat_a.stride(0);
  int64_t ldb = mat_b.stride(1);
  int64_t ldd = out.stride(0);

  ElementOutput* bias_ptr = nullptr;
  int64_t ldc = 0;  // leading dimension 0 broadcasts the bias row across all rows
  if (bias) {
    bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
  }

  typename EpilogueOutputOp::Params linearScalingParams;
  typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};

  // Scale operands: {ptr, 0} — stride-0 vectors (b_s per column, a_s per row).
  typename Gemm::Arguments args{
      {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};

  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
|
||||
|
||||
// Tile-shape dispatcher for the SM75 (Turing) int8 GEMM, keyed on m only.
// Every branch uses a 2-stage pipeline; threadblock/warp shapes below are a
// tuning table forwarded to cutlass_int8_scaled_mm.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  if (m <= 32) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<32, 128, 64>,
        cutlass::gemm::GemmShape<32, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 64) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        2>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
|
||||
|
||||
// Tile-shape dispatcher for the SM80 (Ampere) int8 GEMM, keyed on (m, n).
// Threadblock/warp shapes and stage counts below are a tuning table forwarded
// to cutlass_int8_scaled_mm.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 16) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<16, 64, 128>,
          cutlass::gemm::GemmShape<16, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 32) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          6>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<32, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n <= 4096) {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 64, 128>,
          cutlass::gemm::GemmShape<32, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      cutlass_int8_scaled_mm<
          ElementOutput,
          ArchTag,
          cutlass::gemm::GemmShape<64, 128, 128>,
          cutlass::gemm::GemmShape<64, 64, 64>,
          InstructionShape,
          5>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128 && n < 8192) {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<64, 128, 128>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    cutlass_int8_scaled_mm<
        ElementOutput,
        ArchTag,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape,
        5>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
|
||||
|
||||
// Int8 scaled GEMM for Hopper (SM90) via the CUTLASS 3.x collective-builder API.
// Computes out = (mat_a @ mat_b) * scales_a (per row) * scales_b (per column)
// [+ bias], with int8 inputs, int32 accumulation, and float epilogue compute.
//
// Template parameters:
//   ElementOutput        - output element type (half / bfloat16)
//   TileShape            - CTA tile shape (cute::Shape)
//   ClusterShape         - thread-block cluster shape
//   MainloopScheduleType - TMA warp-specialized mainloop schedule
//   WithBias             - whether a row-broadcast bias is fused in the epilogue
//
// Preconditions (validated by the caller int8_scaled_mm): mat_a row-major,
// mat_b column-major, scales float32 and contiguous.
// Fix vs. previous revision: corrected typo in the execution error message.
template <
    typename ElementOutput,
    typename TileShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    bool WithBias>
void cutlass_int8_scaled_mm_sm90(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using ArchTag = cutlass::arch::Sm90;

  using ElementAccumulator = int32_t;
  using ElementCompute = float;
  using ElementInputA = int8_t;
  using ElementInputB = int8_t;

  // 128-bit alignment (in elements) for TMA-friendly global accesses.
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
  using TileSchedulerType = cutlass::gemm::PersistentScheduler;

  // Per-row scale of A (broadcast along columns).
  using XScale = cutlass::epilogue::fusion::
      Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;

  // Per-column scale of B (broadcast along rows).
  using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;

  // Optional per-column bias (broadcast along rows), in the output dtype.
  using Bias = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // Scale: first acc * b_scale in float ...
  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  // ... then * a_scale, converting to the output element type.
  using Compute1 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias: a_scale * (b_scale * acc) + bias, fused as a multiply-add.
  using ComputeWithBias = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentC,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentOutput,
      EpilogueScheduleType,
      EpilogueEVT>::CollectiveOp;

  // Let the builder pick the mainloop stage count after carving out the
  // epilogue's shared-memory footprint.
  using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
      sizeof(typename CollectiveEpilogue::SharedStorage))>;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementInputA,
      cutlass::layout::RowMajor,
      AlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      Stages,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = mat_a.size(0);
  int k = mat_a.size(1);
  int n = mat_b.size(1);

  auto a_ptr = static_cast<ElementInputA*>(mat_a.data_ptr());
  auto b_ptr = static_cast<ElementInputB*>(mat_b.data_ptr());
  auto o_ptr = static_cast<ElementOutput*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementCompute*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementCompute*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;  // unused: no source C tensor (pointer below is nullptr)
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {a_ptr, stride_a, b_ptr, stride_b},
      {{},  // epilogue.thread, filled in below depending on WithBias
       nullptr,
       stride_c,
       o_ptr,
       stride_d}};

  if constexpr (WithBias) {
    ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
    args.epilogue.thread = {
        {a_s_ptr},
        {{b_s_ptr}, {}, {}},
        {bias_ptr},
        {},
    };
  } else {
    args.epilogue.thread = {
        {a_s_ptr},
        {{b_s_ptr}, {}, {}},
        {},
    };
  }

  // CUTLASS may need device workspace (e.g. for the persistent scheduler);
  // allocate it through torch so it lives on the right device.
  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
|
||||
|
||||
// Resolve the optional bias into the compile-time WithBias flag and forward
// to the SM90 CUTLASS kernel wrapper.
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  if (bias.has_value()) {
    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
    return;
  }
  cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, false>(
      out, mat_a, mat_b, scales_a, scales_b, bias);
}
|
||||
|
||||
// Heuristic tile/cluster-shape selection for the SM90 int8 GEMM, keyed on the
// problem dimensions m (rows of A) and n (columns of B). Small m favors small
// CTA tiles with wider clusters; large m uses a 128x128 pingpong schedule.
template <typename ElementOutput>
void sm90_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Cooperative = cutlass::gemm::KernelTmaWarpSpecialized;

  const int m = mat_a.size(0);
  const int n = mat_b.size(1);

  if (m <= 32) {
    if (n < 8192) {
      return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_1, _8, _1>, Cooperative>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_1, _8, _1>, Cooperative>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  }

  if (m <= 64) {
    if (n < 8192) {
      return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_1, _4, _1>, Cooperative>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _256>, Shape<_1, _1, _1>, Cooperative>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  }

  if (m <= 128) {
    if (n <= 4096) {
      return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_2, _1, _1>, Cooperative>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_2, _1, _1>, Cooperative>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  }

  // Large m: pingpong schedule with the biggest tile.
  return sm90_dispatch_bias<
      ElementOutput,
      Shape<_128, _128, _128>,
      Shape<_2, _1, _1>,
      cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
|
||||
|
||||
// Entry point for int8 scaled matrix multiplication:
//   out = (mat_a @ mat_b) * scales_a (per row) * scales_b (per column) [+ bias]
// mat_a: row-major int8 (m, k); mat_b: column-major int8 (k, n);
// scales: contiguous float32; out_dtype selects half or bfloat16 output.
// Dispatches to an SM75 / SM80 / SM90 CUTLASS path based on the device's
// compute capability (SM90 needs CUDA >= 12.0, otherwise falls back to the
// CUTLASS 2.x SM80 path).
//
// Fixes vs. previous revision: the column-major check on mat_b reported
// "mat_a" in its error message, and "msut" typo in the scales_b message.
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  // Input layout / shape / dtype validation.
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
  TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment");  // out.stride(0)
  TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
  TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  // Scale validation: one scale per row of A, one per column of B.
  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    // Bias is per output column and must match the output dtype.
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));

  auto sm_version = getSMVersion();

  if (sm_version >= 75 && sm_version < 80) {
    // Turing path supports fp16 output only.
    TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
    sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
  } else if (sm_version >= 80 && sm_version < 90) {
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    // cutlass 3.x
    if (out_dtype == torch::kBFloat16) {
      sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
#else
    // fallback to cutlass 2.x when built with CUDA < 12
    if (out_dtype == torch::kBFloat16) {
      sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
          out, mat_a, mat_b, scales_a, scales_b, bias);
    }
#endif
  } else {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
  }

  return out;
}
|
||||
125
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
Normal file
125
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
Normal file
@@ -0,0 +1,125 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Compute the per-tensor FP8 scale: each thread accumulates a private running
// absolute maximum (16-byte vectorized loads plus a scalar tail, grid-stride),
// the block reduces, and thread 0 folds absmax / FP8_E4M3_MAX into *output_s
// with an atomic max.
template <typename T>
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  constexpr uint32_t kVecSize = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, kVecSize>;

  const unsigned int thread_rank = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;

  float thread_max = 0.0f;

  // Vectorized main loop over full 16-byte chunks.
  const int32_t num_vec_elems = num_elements / kVecSize;
  for (int32_t i = thread_rank; i < num_vec_elems; i += stride) {
    vec_t chunk;
    chunk.cast_load(input + i * kVecSize);
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      thread_max = fmaxf(thread_max, fabsf(static_cast<float>(chunk[j])));
    }
  }

  // Scalar tail for elements that do not fill a complete vector.
  for (int32_t idx = num_vec_elems * kVecSize + thread_rank; idx < num_elements; idx += stride) {
    thread_max = fmaxf(thread_max, fabsf(static_cast<float>(input[idx])));
  }

  thread_max = blockReduceMax(thread_max);

  if (threadIdx.x == 0) {
    atomicMaxFloat(output_s, thread_max / FP8_E4M3_MAX);
  }
}
|
||||
|
||||
// Quantize `input` to FP8 using a single precomputed per-tensor scale.
// Grid-stride loop: each thread processes full 16-byte vectors plus a scalar
// tail; values are clamped to [-FP8_E4M3_MAX, FP8_E4M3_MAX].
//
// Fix vs. previous revision: the ROCm conversion referenced an undeclared
// identifier `value` (the local variable is `val`), which failed to compile
// under USE_ROCM; the sibling per-token kernel already used `val`.
template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  // Multiply by the reciprocal instead of dividing per element.
  const float scale_val = 1.0f / (*scale);

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int32_t num_vec_elems = num_elements / vec_size;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);

    FP8_TYPE output_arr[vec_size];
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      output[i * vec_size + j] = output_arr[j];
    }
  }

  // Scalar tail for elements beyond the last full vector.
  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    output[idx] = static_cast<FP8_TYPE>(val);
#else
    output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}
|
||||
|
||||
// Quantize `input` to FP8 with a single per-tensor scale.
// If is_static is false, the scale is computed on-device from the tensor's
// absolute maximum; otherwise the caller-provided output_s is used as-is.
void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  constexpr int kThreadsPerBlock = 256;
  const int num_elements = input.numel();
  // Cap the grid at 1024 blocks; the kernels grid-stride over the remainder.
  const int block_count = min((num_elements + kThreadsPerBlock - 1) / kThreadsPerBlock, 1024);

  const dim3 grid(block_count);
  const dim3 block(kThreadsPerBlock);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (!is_static) {
      // Dynamic quantization: derive the scale from the tensor's absmax first.
      per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
    }

    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
|
||||
105
sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
Normal file
105
sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
Normal file
@@ -0,0 +1,105 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
|
||||
// Tree-reduce the maximum of 16 floats in smem[0..15] down to smem[0] and
// return it. Must be called only by threads with tid < 8: the first fmaxf
// folds smem[tid + 8] unconditionally, so callers guard the invocation
// (see per_token_group_quant_fp8_kernel).
// NOTE(review): relies on volatile shared-memory accesses for visibility
// between the participating lanes without __syncwarp(); assumes the 8 callers
// sit in the same warp — confirm for the target architectures.
__device__ __forceinline__ float GroupReduceMax(volatile float* smem, const int tid) {
  smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
  if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
  if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
  if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
  return smem[0];
}
|
||||
|
||||
// Per-group FP8 quantization: each contiguous `group_size`-element slice of
// `input` gets its own scale (absmax / fp8_max) written to output_s, and its
// values quantized into output_q with clamping to [fp8_min, fp8_max].
//
// Launch layout: 16 groups per block, 16 cooperating threads per group
// (blockDim.x == 256 is expected by the tid / 16 decomposition below).
//
// Fix vs. previous revision: __syncthreads() was executed inside a branch
// that diverges within the last block when num_groups is not a multiple of
// 16 (undefined behavior per the CUDA programming guide). The barriers are
// now reached by every thread of the block; the per-group work remains
// guarded by `group_in_range`.
template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int block_group_id = blockIdx.x * groups_per_block;
  const int tid = threadIdx.x;
  const int local_group_id = tid / 16;  // which of this block's 16 groups
  const int local_tid = tid % 16;       // lane within the group

  // 16 partial maxima per group; inner dimension padded to 17 to avoid
  // shared-memory bank conflicts.
  __shared__ float s_absmax[16][17];

  // The last block may own fewer than 16 valid groups.
  const bool group_in_range = block_group_id + local_group_id < num_groups;

  // Seed with eps so the derived scale is never zero.
  float local_absmax = eps;

  if (group_in_range) {
    const T* group_input = input + (block_group_id + local_group_id) * group_size;
    // Each of the 16 lanes accumulates a partial absmax over the group.
    for (int i = local_tid; i < group_size; i += 16) {
      local_absmax = fmaxf(local_absmax, fabsf(static_cast<float>(group_input[i])));
    }
  }

  // All threads write a partial (eps for out-of-range groups) so the barrier
  // and reduction below are uniform across the block.
  s_absmax[local_group_id][local_tid] = local_absmax;
  __syncthreads();

  // Reduce the 16 partials down to s_absmax[local_group_id][0].
  if (local_tid < 8) {
    GroupReduceMax(&s_absmax[local_group_id][0], local_tid);
  }
  __syncthreads();

  if (group_in_range) {
    const T* group_input = input + (block_group_id + local_group_id) * group_size;
    FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + (block_group_id + local_group_id) * group_size;
    float* scale_output = output_s + block_group_id + local_group_id;

    const float group_absmax = s_absmax[local_group_id][0];
    const float y_s = group_absmax / fp8_max;

    if (local_tid == 0) {
      *scale_output = y_s;
    }

    // Quantize: divide by the group scale and clamp to the FP8 range.
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
      group_output[i] = FP8_TYPE(q_val);
    }
  }
}
|
||||
|
||||
// Launch per-group FP8 quantization: `input` is treated as a flat sequence of
// `group_size`-element groups; each group gets its own scale in output_s.
void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  // The tensor must split evenly into groups.
  CHECK_EQ(input.numel() % group_size, 0);
  const int num_groups = input.numel() / group_size;

  // 16 groups per block, 16 threads per group -> 256 threads per block.
  const dim3 grid((num_groups + 15) / 16);
  const dim3 block(256);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        static_cast<float>(eps),
        static_cast<float>(fp8_min),
        static_cast<float>(fp8_max));
    return true;
  });
}
|
||||
111
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
111
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
Normal file
@@ -0,0 +1,111 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Per-token (per-row) FP8 quantization: one thread block per token computes
// the row's absolute maximum, publishes scale = absmax / FP8_E4M3_MAX to
// output_s[token_idx], then quantizes the row into output_q using 16-byte
// vectorized loads plus a scalar tail. Values are clamped to
// [-FP8_E4M3_MAX, FP8_E4M3_MAX].
template <typename T>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;

  // Uniform per-block guard (token_idx is the same for the whole block, so
  // this early return cannot diverge around the barrier below).
  if (token_idx >= num_tokens) return;

  const int tid = threadIdx.x;
  const int block_dim = blockDim.x;

  const T* token_input = input + token_idx * hidden_dim;
  FP8_TYPE* token_output = output_q + token_idx * hidden_dim;

  // Phase 1: block-wide absolute maximum of this token's row.
  float max_value = 0.0f;

  for (int i = tid; i < hidden_dim; i += block_dim) {
    float val = static_cast<float>(token_input[i]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  max_value = blockReduceMax(max_value);

  // Thread 0 derives and publishes the scale; the barrier makes it visible
  // to all threads before quantization starts.
  __shared__ float block_max;
  if (tid == 0) {
    block_max = max_value / FP8_E4M3_MAX;
    output_s[token_idx] = block_max;
  }
  __syncthreads();

  // NOTE(review): an all-zero row yields block_max == 0 and scale_val == inf
  // (0 * inf -> NaN downstream) — confirm callers never pass all-zero tokens.
  const float scale_val = 1.0f / block_max;

  // Phase 2: vectorized quantization over full 16-byte chunks.
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int32_t num_vec_elems = hidden_dim / vec_size;

  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * vec_size);

    FP8_TYPE output_arr[vec_size];
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      // Scale and clamp to the representable FP8 range.
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      token_output[i * vec_size + j] = output_arr[j];
    }
  }

  // Scalar tail for elements beyond the last full vector.
  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    token_output[idx] = static_cast<FP8_TYPE>(val);
#else
    token_output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}
|
||||
|
||||
// Launch per-token FP8 quantization over a 2-D (num_tokens, hidden_dim)
// tensor: one block per token row, one scale per token written to output_s.
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const auto shape = input.sizes();
  const int64_t num_tokens = shape[0];
  const int64_t hidden_dim = shape[1];

  // One 128-thread block per token row.
  constexpr int kThreadsPerBlock = 128;
  const int block_count = num_tokens;

  const dim3 grid(block_count);
  const dim3 block(kThreadsPerBlock);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
    return true;
  });
}
|
||||
Reference in New Issue
Block a user