[sgl-kernel] 1/N Refactor sglang cutlass 3x - gemm fp8 blockwise sm90 (#8913)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
21
sgl-kernel/csrc/cutlass_extensions/common.hpp
Normal file
21
sgl-kernel/csrc/cutlass_extensions/common.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/**
|
||||
* A wrapper for a kernel that is used to guard against compilation on
|
||||
* architectures that will never use the kernel. The purpose of this is to
|
||||
* reduce the size of the compiled binary.
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,62 @@
|
||||
// Adapted from
|
||||
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
// clang-format on
|
||||
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(
|
||||
torch::Device device,
|
||||
cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = c10::cuda::current_device();
|
||||
hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args, hw_info, scheduler};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
38
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
38
sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
|
||||
// `ScaleGranularityM` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})` along M.
|
||||
template <int ScaleGranularityM = 0>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelTmaWarpSpecializedCooperative {};
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
|
||||
// specialized dynamic schedule For FP8 kernels with Block Scaling
|
||||
template <
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1, _1, _1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
|
||||
// while zero-value `ScaleGranularityM` indicates that scaling
|
||||
// granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||
static_assert(
|
||||
cute::
|
||||
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@@ -0,0 +1,197 @@
|
||||
// Adapted from
|
||||
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
typename SchedulerType,
|
||||
typename OutType,
|
||||
int GroupSizeM_,
|
||||
int GroupSizeN_,
|
||||
int GroupSizeK_,
|
||||
int TileSizeM_ = 128,
|
||||
class ClusterShape = Shape<_1, _2, _1>>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
using GroupSizeK = Int<GroupSizeK_>;
|
||||
using TileSizeM = Int<TileSizeM_>;
|
||||
|
||||
static_assert(TileSizeM_ % GroupSizeM_ == 0, "TileSizeM must be a multiple of GroupSizeM");
|
||||
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = ElementAB;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for compute
|
||||
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>; // Threadblock-level tile size
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
SchedulerType>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementA = ElementAB;
|
||||
using ElementB = ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
int n = b.size(1);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementB*>(b.data_ptr());
|
||||
|
||||
auto a_s_ptr = static_cast<ElementBlockScale*>(a_scales.data_ptr());
|
||||
auto b_s_ptr = static_cast<ElementBlockScale*>(b_scales.data_ptr());
|
||||
|
||||
using StrideA = typename GemmKernel::StrideA;
|
||||
using StrideB = typename GemmKernel::StrideB;
|
||||
using StrideD = typename GemmKernel::StrideD;
|
||||
using StrideC = typename GemmKernel::StrideC;
|
||||
|
||||
StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, a_stride, b_ptr, b_stride, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::TileSchedulerArguments scheduler;
|
||||
|
||||
static constexpr bool UsesStreamKScheduler =
|
||||
cute::is_same_v<typename GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>;
|
||||
|
||||
if constexpr (UsesStreamKScheduler) {
|
||||
using DecompositionMode =
|
||||
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode =
|
||||
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
|
||||
scheduler.decomposition_mode = DecompositionMode::StreamK;
|
||||
scheduler.reduction_mode = ReductionMode::Nondeterministic;
|
||||
}
|
||||
|
||||
cutlass_gemm_caller<GemmKernel>(a.device(), {m, n, k, 1}, mainloop_args, epilogue_args, scheduler);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
|
||||
if (k > 3 * n) {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<
|
||||
cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
155
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Executable file → Normal file
155
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
Executable file → Normal file
@@ -30,138 +30,12 @@
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename SchedulerType, typename OutType, typename TileShape, typename ClusterShape>
|
||||
void launch_sm90_fp8_blockwise_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b) {
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ScaleTileShape = Shape<_1, _128, _128>;
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
SchedulerType>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
int n = b.size(1);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementB*>(b.data_ptr());
|
||||
auto o_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
|
||||
auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
StrideC stride_c;
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
|
||||
|
||||
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, stride_a, b_ptr, stride_b, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
|
||||
|
||||
typename Gemm::Arguments args = {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
};
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
|
||||
|
||||
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
|
||||
}
|
||||
|
||||
template <
|
||||
typename OutType,
|
||||
typename MmaTileShape,
|
||||
@@ -297,27 +171,6 @@ void launch_sm100_fp8_blockwise_scaled_mm(
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm90_fp8_blockwise_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b) {
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
if (k > 3 * n) {
|
||||
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::StreamKScheduler, OutType, TileShape, ClusterShape>(
|
||||
out, a, b, scales_a, scales_b);
|
||||
} else {
|
||||
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::PersistentScheduler, OutType, TileShape, ClusterShape>(
|
||||
out, a, b, scales_a, scales_b);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm100_fp8_blockwise_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
@@ -394,10 +247,10 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
if (sm_version == 90) {
|
||||
torch::Tensor scales_b_contiguous = scales_b.contiguous();
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
|
||||
} else {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
|
||||
}
|
||||
return out_padded.slice(0, 0, original_rows);
|
||||
|
||||
Reference in New Issue
Block a user