Support fp8 gemm for blackwell (#4558)
This commit is contained in:
@@ -792,6 +792,282 @@ void sm90_fp8_dispatch_shape(
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||
template <
|
||||
typename ElementType,
|
||||
typename OutElementType,
|
||||
typename AccumElementType,
|
||||
typename CTAShape,
|
||||
typename ClusterShape,
|
||||
typename MainloopScheduleType,
|
||||
typename EpilogueScheduleType,
|
||||
typename TileSchedulerType = void,
|
||||
bool WithBias = false>
|
||||
struct DeviceGemmFp8RowwiseSm100 {
|
||||
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
|
||||
using TileShape = CTAShape;
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using ElementComputeEpilogue = float;
|
||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
OutElementType,
|
||||
OutElementType,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;
|
||||
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using Compute1MulAdd = cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::multiply_add, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using Compute1Mul = cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::multiplies, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute = typename std::conditional_t<
|
||||
WithBias,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
// MMA type
|
||||
using ElementAccumulator = AccumElementType;
|
||||
|
||||
// Epilogue types
|
||||
using ElementCompute = float;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
OutElementType,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
ElementType,
|
||||
LayoutA,
|
||||
AlignmentA,
|
||||
ElementType,
|
||||
LayoutB,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduleType>::CollectiveOp;
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
static_assert(
|
||||
std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> || std::is_same_v<Descriptor, Bias>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
|
||||
public:
|
||||
static ArgumentType prepare_args(
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias = std::nullopt) {
|
||||
auto a_args = args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
|
||||
if constexpr (WithBias) {
|
||||
auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
} else {
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmType, bool WithBias>
|
||||
typename GemmType::Gemm::Arguments prepare_sm100_fp8_args(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using Gemm = typename GemmType::Gemm;
|
||||
using ElementT = typename Gemm::ElementA;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
using ElementOutput = typename Gemm::ElementD;
|
||||
using ElementComputeEpilogue = float;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = StrideC;
|
||||
using StrideAux = StrideC;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
|
||||
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
|
||||
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
|
||||
StrideAux aux_stride = stride_d;
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};
|
||||
|
||||
typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {};
|
||||
|
||||
auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());
|
||||
|
||||
auto prepare_epilogue_args = [&](const c10::optional<torch::Tensor>& bias = c10::nullopt) {
|
||||
if constexpr (WithBias) {
|
||||
TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
|
||||
return typename GemmKernel::EpilogueArguments{
|
||||
GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
|
||||
} else {
|
||||
return typename GemmKernel::EpilogueArguments{
|
||||
GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
|
||||
}
|
||||
};
|
||||
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape,
|
||||
mainloop_args,
|
||||
prepare_epilogue_args(bias),
|
||||
hw_info,
|
||||
scheduler};
|
||||
return args;
|
||||
}
|
||||
|
||||
template <typename Gemm, bool WithBias>
|
||||
void launch_sm100_fp8_scaled_mm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
auto args = prepare_sm100_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
|
||||
|
||||
typename Gemm::Gemm gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
auto can_implement = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
|
||||
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess)
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm100_fp8_dispatch_bias(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using CTAShape = Shape<_256, _128, _64>;
|
||||
using ClusterShape = Shape<_2, _2, _1>;
|
||||
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileSchedulerType = void;
|
||||
|
||||
using ElementInput = cutlass::float_e4m3_t;
|
||||
using ElementOutput = OutType;
|
||||
using AccumElementType = float;
|
||||
|
||||
if (bias) {
|
||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void sm100_fp8_dispatch_shape(
|
||||
torch::Tensor& out,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
return sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
#endif
|
||||
|
||||
torch::Tensor fp8_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
@@ -833,6 +1109,17 @@ torch::Tensor fp8_scaled_mm(
|
||||
|
||||
auto sm_version = getSMVersion();
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||
if (sm_version >= 100) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
sm100_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
if (sm_version >= 90) {
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
|
||||
Reference in New Issue
Block a user