[sgl-kernel] 1/N Refactor sglang cutlass 3x - gemm fp8 blockwise sm90 (#8913)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit was authored by Yuan Luo on 2025-08-15 01:55:54 +08:00 and committed via GitHub.
parent 1fea998a45
commit 432f2053dd
5 changed files with 322 additions and 151 deletions

155
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu Executable file → Normal file
View File

@@ -30,138 +30,12 @@
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
#include "cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh"
#include "utils.h"
using namespace cute;
// Builds and launches one CUTLASS 3.x SM90 (Hopper) FP8 blockwise-scaled GEMM:
// out = (a, scales_a) x (b, scales_b), with e4m3 FP8 operands, fp32
// accumulation, and the result stored as OutType.
//
// Template parameters:
//   SchedulerType - CUTLASS tile scheduler (caller picks persistent vs. Stream-K).
//   OutType       - element type of `out` (e.g. cutlass::bfloat16_t / half_t).
//   TileShape     - CTA tile shape (M, N, K).
//   ClusterShape  - thread-block cluster shape.
//
// Tensor contract (as used below): m = a.size(0), k = a.size(1), n = b.size(1);
// a is row-major, b column-major, out row-major. scales_a / scales_b hold fp32
// block scale factors whose exact layout is deduced by ScaleConfig.
// Raises via TORCH_CHECK if the kernel cannot implement the problem or the
// launch fails.
template <typename SchedulerType, typename OutType, typename TileShape, typename ClusterShape>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  // fp32 accumulation / epilogue compute / scale-factor element type.
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  // 128-bit (16-byte) alignment in elements, as required for TMA loads.
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // C is void: the epilogue reads no source tensor (pure D = accumulator store).
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;

  // Blockwise-scale granularity config. Shape<_1,_128,_128> presumably means
  // one scale per 1x128 sliver of A and per 128x128 tile of B (DeepSeek-style
  // blockwise FP8) -- TODO confirm against ScaleConfig's deduced layouts.
  using ScaleTileShape = Shape<_1, _128, _128>;
  using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  // Epilogue fusion tree: just fetch and store the accumulator (no bias/act).
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
  // Mainloop schedule with FP8 block-scaled accumulation fused in-kernel.
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;

  // Mainloop: A/B layouts are paired with their scale-factor layouts; the
  // pipeline stage count is auto-derived after carving out the epilogue's
  // shared-memory footprint (hence the dependency on CollectiveEpilogue).
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue,
      SchedulerType>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

  // Problem extents taken from the tensors; single batch (L = 1) throughout.
  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());
  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  // Packed (fully contiguous) strides for each operand.
  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  // NOTE(review): stride_c is declared but never used -- the C slot of
  // EpilogueArguments below is filled with stride_d instead. Harmless since
  // ElementC is void (no C tensor is read), but worth tidying upstream.
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, stride_a, b_ptr, stride_b, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
  // Epilogue: empty fusion args, null C pointer (C is void), D = o_ptr.
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},  // (M, N, K, batch L = 1)
      mainloop_args,
      epilogue_args,
  };

  // Allocate the kernel workspace through torch so it lands on a's device and
  // is managed by the caching allocator.
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  // Launch on the current PyTorch CUDA stream for a's device.
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  // Validate the instantiated kernel against this concrete problem first.
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
template <
typename OutType,
typename MmaTileShape,
@@ -297,27 +171,6 @@ void launch_sm100_fp8_blockwise_scaled_mm(
TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}
// SM90 entry point: picks a CUTLASS tile scheduler for the FP8 blockwise
// GEMM based on the problem shape, then launches with a fixed 128x128x128
// CTA tile and a 1x2x1 cluster.
//
// Heuristic: when K greatly exceeds N (k > 3*n), the Stream-K scheduler is
// used -- presumably because such K-dominated shapes leave too few output
// tiles for the persistent scheduler to balance well; all other shapes take
// the persistent scheduler.
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using CtaTile = Shape<_128, _128, _128>;
  using Cluster = Shape<_1, _2, _1>;

  const auto reduction_extent = a.size(1);  // K
  const auto output_cols = b.size(1);       // N
  const bool k_dominated = reduction_extent > 3 * output_cols;

  if (k_dominated) {
    launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::StreamKScheduler, OutType, CtaTile, Cluster>(
        out, a, b, scales_a, scales_b);
    return;
  }
  launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::PersistentScheduler, OutType, CtaTile, Cluster>(
      out, a, b, scales_a, scales_b);
}
template <typename OutType>
void sm100_fp8_blockwise_dispatch_shape(
torch::Tensor& out,
@@ -394,10 +247,10 @@ torch::Tensor fp8_blockwise_scaled_mm(
if (sm_version == 90) {
torch::Tensor scales_b_contiguous = scales_b.contiguous();
if (out_dtype == torch::kBFloat16) {
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
} else {
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
}
return out_padded.slice(0, 0, original_rows);