support w8a8 fp8 kernel with CUTLASS (#3047)
Co-authored-by: yych0745 <1398089567@qq.com>
This commit is contained in:
164
sgl-kernel/benchmark/bench_fp8_gemm.py
Normal file
164
sgl-kernel/benchmark/bench_fp8_gemm.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
||||||
|
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||||
|
|
||||||
|
# Weight Shapes are in the format
|
||||||
|
# ([K, N], TP_SPLIT_DIM)
|
||||||
|
# Example:
|
||||||
|
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 14336, N = 4096
|
||||||
|
# - TP2 : K = 7168, N = 4096
|
||||||
|
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 4096, N = 6144
|
||||||
|
# - TP4 : K = 4096, N = 1536
|
||||||
|
|
||||||
|
# TP1 shapes
|
||||||
|
# Per-model GEMM weight shapes at TP=1, keyed by model id.
# Each entry is ([K, N], tp_split_dim). The [K, N] pair is deliberately a list
# (not a tuple): prepare_shapes() rewrites one element of a copy of it and
# appends the model name to that copy.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1), ([4096, 4096], 0),
        ([4096, 28672], 1), ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1), ([8192, 8192], 0),
        ([8192, 57344], 1), ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1), ([12288, 12288], 0),
        ([12288, 57344], 1), ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1), ([3584, 3584], 0),
        ([3584, 37888], 1), ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1), ([5120, 5120], 0),
        ([5120, 55296], 1), ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1), ([8192, 8192], 0),
        ([8192, 59136], 1), ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1), ([2048, 4096], 1),
        ([2048, 2048], 0), ([2048, 576], 0),
        ([2048, 21888], 1), ([10944, 2048], 0),
        ([2048, 2816], 1), ([1408, 2048], 0),
    ],
}
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one FP8 scaled-matmul provider for an (M=batch_size, N, K) problem.

    Args:
        batch_size: Number of rows M of the activation matrix.
        provider: One of the ``line_vals`` above. The substring ``"fp16"``
            selects a float16 output (bfloat16 otherwise); the prefix selects
            the vLLM or the sgl-kernel implementation.
        N: Number of rows of the weight matrix before it is transposed.
        K: Shared inner dimension.

    Returns:
        A ``(value, lower, upper)`` tuple derived from the 50th, 80th and 20th
        percentile timings respectively.

    Raises:
        ValueError: If ``provider`` names neither a vLLM nor an sglang variant.
    """
    M = batch_size
    # Constant-valued inputs: only timing matters here, not numerical content.
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    # Quantize as (N, K), then hand the kernels the transposed (K, N) view.
    b_fp8 = b_fp8.t()
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:

        def run():
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    elif "sglang-fp8" in provider:

        def run():
            return sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            )

    else:
        # Previously an unknown provider fell through and crashed later with an
        # UnboundLocalError on ``ms``; fail early with a clear message instead.
        raise ValueError(f"Unknown provider: {provider!r}")

    # Timings are unpacked in the same order as ``quantiles``.
    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    def gbps(elapsed_ms):
        # NOTE(review): this scales an operation count (2*M*N*K + M*N) by the
        # element size of the unquantized ``a`` and is plotted as "GB/s";
        # confirm that this is the intended metric.
        return (
            (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (elapsed_ms * 1e-3)
        )

    # A longer time means a lower rate, so max_ms yields the lower bound.
    return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
    """Expand every (model, tp_size) combination into ``[K, N, model]`` rows.

    Args:
        args: Object exposing ``models`` (iterable of keys of
            ``WEIGHT_SHAPES``) and ``tp_sizes`` (iterable of ints).

    Returns:
        A list with one ``[K, N, model_name]`` list per weight shape and per
        (model, tp_size) pair, where the dimension selected by the shape's
        ``tp_split_dim`` has been floor-divided by ``tp_size``.
    """
    rows = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES
        for kn, split_dim in WEIGHT_SHAPES[model]:
            # Work on a private copy so the module-level table stays intact.
            row = list(kn)
            row[split_dim] = row[split_dim] // tp_size
            row.append(model)
            rows.append(row)
    return rows
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command line: which models and tensor-parallel sizes to sweep.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    cli.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = cli.parse_args()

    # One benchmark sweep (over all batch sizes and providers) per GEMM shape.
    for K, N, model_name in prepare_shapes(args):
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
|
||||||
@@ -56,6 +56,7 @@ include_dirs = [
|
|||||||
turbomind.resolve(),
|
turbomind.resolve(),
|
||||||
turbomind.resolve() / "src",
|
turbomind.resolve() / "src",
|
||||||
]
|
]
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
"-DNDEBUG",
|
"-DNDEBUG",
|
||||||
f"-DOPERATOR_NAMESPACE={operator_namespace}",
|
f"-DOPERATOR_NAMESPACE={operator_namespace}",
|
||||||
@@ -82,6 +83,7 @@ sources = [
|
|||||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||||
|
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||||
"3rdparty/flashinfer/csrc/activation.cu",
|
"3rdparty/flashinfer/csrc/activation.cu",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
|
|||||||
bmm_fp8,
|
bmm_fp8,
|
||||||
custom_dispose,
|
custom_dispose,
|
||||||
custom_reduce,
|
custom_reduce,
|
||||||
|
fp8_scaled_mm,
|
||||||
fused_add_rmsnorm,
|
fused_add_rmsnorm,
|
||||||
gelu_and_mul,
|
gelu_and_mul,
|
||||||
gelu_tanh_and_mul,
|
gelu_tanh_and_mul,
|
||||||
@@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"bmm_fp8",
|
"bmm_fp8",
|
||||||
"custom_dispose",
|
"custom_dispose",
|
||||||
"custom_reduce",
|
"custom_reduce",
|
||||||
|
"fp8_scaled_mm",
|
||||||
"fused_add_rmsnorm",
|
"fused_add_rmsnorm",
|
||||||
"gelu_and_mul",
|
"gelu_and_mul",
|
||||||
"gelu_tanh_and_mul",
|
"gelu_tanh_and_mul",
|
||||||
|
|||||||
624
sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
Normal file
624
sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
Normal file
@@ -0,0 +1,624 @@
|
|||||||
|
// Adapted from
|
||||||
|
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
|
||||||
|
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
|
||||||
|
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
#include <cutlass/arch/arch.h>
|
||||||
|
#include <cutlass/arch/memory.h>
|
||||||
|
#include <cutlass/arch/mma.h>
|
||||||
|
#include <cutlass/array.h>
|
||||||
|
#include <cutlass/cutlass.h>
|
||||||
|
#include <cutlass/epilogue/thread/activation.h>
|
||||||
|
#include <cutlass/epilogue/thread/linear_combination.h>
|
||||||
|
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||||
|
#include <cutlass/gemm/device/gemm.h>
|
||||||
|
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||||
|
#include <cutlass/gemm/gemm.h>
|
||||||
|
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
|
||||||
|
#include <cutlass/gemm/thread/mma.h>
|
||||||
|
#include <cutlass/layout/matrix.h>
|
||||||
|
#include <cutlass/matrix_coord.h>
|
||||||
|
#include <cutlass/numeric_types.h>
|
||||||
|
#include <cutlass/tensor_ref.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <cute/tensor.hpp>
|
||||||
|
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||||
|
#include <cutlass/epilogue/collective/default_epilogue.hpp>
|
||||||
|
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
|
||||||
|
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||||
|
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||||
|
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||||
|
#include <cutlass/util/packed_stride.hpp>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
|
||||||
|
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
|
||||||
|
typename WarpShape, int Stages, bool WithBias, typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
|
||||||
|
template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
|
||||||
|
typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
|
||||||
|
struct DeviceGemmFp8RowwiseSm89 {
|
||||||
|
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
|
||||||
|
|
||||||
|
using ElementA = ElementType;
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||||
|
|
||||||
|
using ElementB = ElementType;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||||
|
|
||||||
|
using ElementC = OutElementType;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||||
|
|
||||||
|
using ElementOutput = OutElementType;
|
||||||
|
using LayoutOutput = cutlass::layout::RowMajor;
|
||||||
|
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||||
|
|
||||||
|
using ElementAccumulator = AccumElementType;
|
||||||
|
using ElementComputeEpilogue = float;
|
||||||
|
using ArchTag = cutlass::arch::Sm89;
|
||||||
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
|
|
||||||
|
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||||
|
// Number of epilogue stages in EVT
|
||||||
|
static constexpr int EVTEpilogueStages = 1;
|
||||||
|
|
||||||
|
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<CtaShape, WarpShape, ElementC,
|
||||||
|
AlignmentC, EVTEpilogueStages>;
|
||||||
|
|
||||||
|
// Definition of EVT
|
||||||
|
using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||||
|
|
||||||
|
using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
|
||||||
|
Stride<_0, _1, _0>>;
|
||||||
|
using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;
|
||||||
|
|
||||||
|
using ComputeAScale =
|
||||||
|
cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
|
||||||
|
Stride<_1, _0, _0>>;
|
||||||
|
using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;
|
||||||
|
|
||||||
|
// With bias
|
||||||
|
using biasSrc =
|
||||||
|
cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
|
||||||
|
using ComputeAScaleWithBias =
|
||||||
|
cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiply_add, ElementC, ElementComputeEpilogue,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
using EpilogueAScaleWithBias =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;
|
||||||
|
|
||||||
|
using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||||
|
OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride<int64_t, _1, _0>>;
|
||||||
|
using EpilogueStore =
|
||||||
|
typename cutlass::platform::conditional<WithBias,
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
|
||||||
|
|
||||||
|
using EpilogueOp = EpilogueStore;
|
||||||
|
|
||||||
|
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||||
|
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB,
|
||||||
|
cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator,
|
||||||
|
ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp,
|
||||||
|
ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Gemm, bool WithBias>
|
||||||
|
typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
using ElementT = typename Gemm::ElementA;
|
||||||
|
using ElementOutput = typename Gemm::ElementD;
|
||||||
|
using ElementComputeEpilogue = float;
|
||||||
|
|
||||||
|
int32_t m = a.size(0);
|
||||||
|
int32_t n = b.size(1);
|
||||||
|
int32_t k = a.size(1);
|
||||||
|
|
||||||
|
int64_t lda = a.stride(0);
|
||||||
|
int64_t ldb = b.stride(1);
|
||||||
|
int64_t ldc = out.stride(0);
|
||||||
|
|
||||||
|
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
|
||||||
|
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
|
||||||
|
ElementOutput const* ptr_bias = nullptr;
|
||||||
|
if constexpr (WithBias) {
|
||||||
|
TORCH_CHECK(bias.has_value())
|
||||||
|
ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
|
||||||
|
}
|
||||||
|
ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
|
||||||
|
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
|
||||||
|
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
|
||||||
|
|
||||||
|
typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode
|
||||||
|
{m, n, k}, // Problem size
|
||||||
|
1, // Split-k factor
|
||||||
|
{}, // Epilogue args
|
||||||
|
ptr_a, // a pointer
|
||||||
|
ptr_b, // b pointer
|
||||||
|
nullptr, // c pointer (unused)
|
||||||
|
nullptr, // d pointer (unused)
|
||||||
|
m * k, // batch stride a (unused)
|
||||||
|
n * k, // batch stride b (unused)
|
||||||
|
m * n, // batch stride c (unused)
|
||||||
|
m * n, // batch stride d (unused)
|
||||||
|
lda, // stride a
|
||||||
|
ldb, // stride b
|
||||||
|
ldc, // stride c (unused)
|
||||||
|
ldc); // stride d (unused)
|
||||||
|
if constexpr (WithBias) {
|
||||||
|
args.epilogue = {{
|
||||||
|
{
|
||||||
|
{}, // Accumulator
|
||||||
|
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
|
||||||
|
{} // Multiplies
|
||||||
|
},
|
||||||
|
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
|
||||||
|
{ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
|
||||||
|
{} // Multiplies
|
||||||
|
},
|
||||||
|
{ptr_d, {n, _1{}, _0{}}}};
|
||||||
|
} else {
|
||||||
|
args.epilogue = {{
|
||||||
|
{
|
||||||
|
{}, // Accumulator
|
||||||
|
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
|
||||||
|
{} // Multiplies
|
||||||
|
},
|
||||||
|
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
|
||||||
|
{} // Multiplies
|
||||||
|
},
|
||||||
|
{ptr_d, {n, _1{}, _0{}}}};
|
||||||
|
}
|
||||||
|
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Gemm, bool WithBias>
|
||||||
|
void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
Gemm gemm_op;
|
||||||
|
|
||||||
|
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||||
|
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||||
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||||
|
|
||||||
|
auto can_implement = gemm_op.can_implement(args);
|
||||||
|
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
|
||||||
|
|
||||||
|
auto status = gemm_op(args, workspace.data_ptr(), stream);
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess)
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
|
||||||
|
void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
using ElementInput = cutlass::float_e4m3_t;
|
||||||
|
using ElementOutput = OutType;
|
||||||
|
using AccumElementType = float;
|
||||||
|
if (bias) {
|
||||||
|
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
|
||||||
|
Stages, true>::Gemm;
|
||||||
|
return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
|
||||||
|
Stages, false>::Gemm;
|
||||||
|
return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OutType>
|
||||||
|
void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
uint32_t const m = a.size(0);
|
||||||
|
uint32_t const n = out.size(1);
|
||||||
|
|
||||||
|
if (m == 1) {
|
||||||
|
if (n <= 8192) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 16) {
|
||||||
|
// M in (1, 16]
|
||||||
|
if (n <= 8192) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else if (n <= 16384) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 64) {
|
||||||
|
// M in (16, 64]
|
||||||
|
if (n <= 16384) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 128) {
|
||||||
|
// M in (64, 128]
|
||||||
|
if (n <= 8192) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else if (n <= 16384) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 256) {
|
||||||
|
// M in (128, 256]
|
||||||
|
if (n <= 8192) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else if (n <= 16384) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else if (m <= 512) {
|
||||||
|
// M in (256, 512)
|
||||||
|
if (n <= 16384) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// M in (512, inf)
|
||||||
|
if (n <= 8192) {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
|
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
|
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
|
||||||
|
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||||
|
typename TileSchedulerType = void, bool WithBias = false>
|
||||||
|
struct DeviceGemmFp8RowwiseSm90 {
|
||||||
|
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
|
||||||
|
|
||||||
|
// A matrix configuration
|
||||||
|
using ElementA = ElementType; // Element type for A matrix operand
|
||||||
|
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||||
|
static constexpr int AlignmentA =
|
||||||
|
128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A
|
||||||
|
// matrix in units of elements (up to 16 bytes)
|
||||||
|
|
||||||
|
// B matrix configuration
|
||||||
|
using ElementB = ElementType; // Element type for B matrix operand
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||||
|
static constexpr int AlignmentB =
|
||||||
|
128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B
|
||||||
|
// matrix in units of elements (up to 16 bytes)
|
||||||
|
|
||||||
|
// C/D matrix configuration
|
||||||
|
using ElementC = void; // Element type for C matrix operands
|
||||||
|
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
|
||||||
|
static constexpr int AlignmentC =
|
||||||
|
128 / cutlass::sizeof_bits<OutElementType>::value; // Memory access granularity/alignment of C matrices in
|
||||||
|
// units of elements (up to 16 bytes)
|
||||||
|
|
||||||
|
// Output matrix configuration
|
||||||
|
using ElementOutput = OutElementType; // Element type for output matrix operands
|
||||||
|
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
|
||||||
|
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||||
|
|
||||||
|
// // Auxiliary matrix configuration and other fusion types
|
||||||
|
// using ElementBias = float;
|
||||||
|
|
||||||
|
// Multiply-accumulate blocking/pipelining details
|
||||||
|
using ElementAccumulator = AccumElementType; // Element type for internal accumulation
|
||||||
|
using ElementCompute = float; // Element type for compute
|
||||||
|
using ElementComputeEpilogue = float;
|
||||||
|
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||||
|
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||||
|
using TileShape = CTAShape; // Threadblock-level tile size
|
||||||
|
|
||||||
|
static constexpr bool PONG = false;
|
||||||
|
static constexpr bool FAST_ACCUM = true;
|
||||||
|
static constexpr bool USE_BIAS = false;
|
||||||
|
|
||||||
|
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized
|
||||||
|
// based on the tile size
|
||||||
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
|
||||||
|
// setting in the Collective Builder
|
||||||
|
// Implement rowwise scaling epilogue.
|
||||||
|
using XScale =
|
||||||
|
cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
|
||||||
|
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
|
||||||
|
|
||||||
|
using WScale =
|
||||||
|
cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
|
||||||
|
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||||
|
|
||||||
|
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
|
||||||
|
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||||
|
|
||||||
|
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||||
|
|
||||||
|
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
|
||||||
|
ElementComputeEpilogue, // First stage output type.
|
||||||
|
ElementComputeEpilogue, // First stage input types.
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
|
||||||
|
|
||||||
|
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput,
|
||||||
|
ElementComputeEpilogue, // Second stage input types.
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
|
||||||
|
|
||||||
|
// With bias
|
||||||
|
using ComputeWithBias =
|
||||||
|
cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementComputeEpilogue,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
|
||||||
|
|
||||||
|
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
|
||||||
|
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC,
|
||||||
|
AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized,
|
||||||
|
EpilogueEVT>::CollectiveOp;
|
||||||
|
|
||||||
|
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||||
|
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||||
|
using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||||
|
using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
|
|
||||||
|
using SlowAccum = DefaultSchedule;
|
||||||
|
using FastAccum = FastPongSchedule; // Default apply Pingpong
|
||||||
|
|
||||||
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
|
||||||
|
TileShape, ClusterShape,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
|
MainloopScheduleType>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
|
||||||
|
CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Gemm, bool WithBias>
|
||||||
|
typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
using ElementT = typename Gemm::ElementA;
|
||||||
|
using ElementOutput = typename Gemm::ElementD;
|
||||||
|
using ElementComputeEpilogue = float;
|
||||||
|
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||||
|
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||||
|
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||||
|
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||||
|
|
||||||
|
int32_t m = a.size(0);
|
||||||
|
int32_t n = b.size(1);
|
||||||
|
int32_t k = a.size(1);
|
||||||
|
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
|
||||||
|
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
|
||||||
|
ElementOutput const* ptr_bias = nullptr;
|
||||||
|
if constexpr (WithBias) {
|
||||||
|
TORCH_CHECK(bias.has_value())
|
||||||
|
ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
|
||||||
|
}
|
||||||
|
ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
|
||||||
|
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
|
||||||
|
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
|
||||||
|
|
||||||
|
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
|
||||||
|
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
|
||||||
|
StrideC stride_c;
|
||||||
|
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
|
||||||
|
typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
|
||||||
|
{m, n, k, 1},
|
||||||
|
{ptr_a, stride_a, ptr_b, stride_b},
|
||||||
|
{{}, // epilogue.thread
|
||||||
|
nullptr,
|
||||||
|
stride_c,
|
||||||
|
ptr_d,
|
||||||
|
stride_d}};
|
||||||
|
if constexpr (WithBias) {
|
||||||
|
args.epilogue.thread = {
|
||||||
|
{ptr_scales_a},
|
||||||
|
{
|
||||||
|
{ptr_scales_b},
|
||||||
|
{}, // Accumulator
|
||||||
|
{} // Multiplies
|
||||||
|
},
|
||||||
|
{ptr_bias},
|
||||||
|
{}, // Multiplies
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
args.epilogue.thread = {
|
||||||
|
{ptr_scales_a},
|
||||||
|
{
|
||||||
|
{ptr_scales_b},
|
||||||
|
{}, // Accumulator
|
||||||
|
{} // Multiplies
|
||||||
|
},
|
||||||
|
{}, // Multiplies
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Gemm, bool WithBias>
|
||||||
|
void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
Gemm gemm_op;
|
||||||
|
|
||||||
|
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||||
|
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||||
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||||
|
|
||||||
|
auto can_implement = gemm_op.can_implement(args);
|
||||||
|
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
|
||||||
|
|
||||||
|
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
|
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess)
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
|
||||||
|
typename TileSchedulerType>
|
||||||
|
void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias, bool fast_accum = true,
|
||||||
|
bool use_persistent = false) {
|
||||||
|
using ElementInput = cutlass::float_e4m3_t;
|
||||||
|
using ElementOutput = OutType;
|
||||||
|
using AccumElementType = float;
|
||||||
|
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
|
||||||
|
if (bias) {
|
||||||
|
using Gemm =
|
||||||
|
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
|
||||||
|
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
|
||||||
|
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
using Gemm =
|
||||||
|
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
|
||||||
|
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
|
||||||
|
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OutType>
|
||||||
|
void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||||
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
|
uint32_t const m = a.size(0);
|
||||||
|
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
|
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||||
|
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
|
||||||
|
using BasicTileScheduler = void;
|
||||||
|
if (m <= 1) {
|
||||||
|
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler,
|
||||||
|
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
if (m <= 64) {
|
||||||
|
// m in [1, 64]
|
||||||
|
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler,
|
||||||
|
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else if (m <= 256) {
|
||||||
|
// m in (64, 256]
|
||||||
|
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
|
||||||
|
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else if (m <= 1024) {
|
||||||
|
// m in (256, 1024]
|
||||||
|
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
|
||||||
|
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
} else {
|
||||||
|
// m in (1024, inf)
|
||||||
|
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler,
|
||||||
|
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Entry point for the FP8 scaled matrix multiply.
//
// Validates the inputs, allocates the output, and dispatches to the kernel family
// matching the device's SM version.
//
// Parameters:
//   mat_a     - 2D CUDA tensor, Float8_e4m3fn, unit stride along dim 1.
//   mat_b     - 2D CUDA tensor, Float8_e4m3fn, unit stride along dim 0;
//               mat_b.size(0) must equal mat_a.size(1).
//   scales_a  - contiguous Float32 tensor with mat_a.size(0) elements.
//   scales_b  - contiguous Float32 tensor with mat_b.size(1) elements.
//   out_dtype - torch::kHalf or torch::kBFloat16.
//   bias      - optional contiguous tensor with mat_b.size(1) elements and dtype out_dtype.
//
// Returns a newly allocated tensor of shape (mat_a.size(0), mat_b.size(1)) with
// dtype `out_dtype`, created with mat_a's tensor options.
//
// Raises (via TORCH_CHECK) on any violated precondition, including the 16-byte
// row-size requirements on mat_a, mat_b and the output, and raises a
// not-implemented error when no kernel path is compiled in for the detected SM version.
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                            const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  // This check inspects mat_b, so the message must name mat_b (it previously said mat_a).
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
              "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
              "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

  // SM90 and newer: only compiled in when building against CUDA 12.0 or later.
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  // SM89: only compiled in when building against CUDA 12.4 or later.
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    if (out_dtype == torch::kBFloat16) {
      sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}
|
||||||
@@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
|
|||||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||||
const c10::optional<torch::Tensor>& bias);
|
const c10::optional<torch::Tensor>& bias);
|
||||||
|
|
||||||
|
// fp8_scaled_mm
|
||||||
|
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||||
|
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||||
|
const c10::optional<torch::Tensor>& bias);
|
||||||
|
|
||||||
// lightning_attention_decode
|
// lightning_attention_decode
|
||||||
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
|
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
|
||||||
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
||||||
|
|||||||
@@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Call the ``sgl_kernels.fp8_scaled_mm`` custom op and return its result.

    Every argument is forwarded positionally and unchanged, in the order
    ``mat_a, mat_b, scales_a, scales_b, out_dtype, bias``. ``bias`` defaults
    to ``None``.
    """
    kernel = torch.ops.sgl_kernels.fp8_scaled_mm
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
|
||||||
|
|
||||||
|
|
||||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||||
torch.ops.sgl_kernels.lightning_attention_decode(
|
torch.ops.sgl_kernels.lightning_attention_decode(
|
||||||
q, k, v, past_kv, slope, output, new_kv
|
q, k, v, past_kv, slope, output, new_kv
|
||||||
|
|||||||
@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
|||||||
"bias) -> Tensor");
|
"bias) -> Tensor");
|
||||||
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
|
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
|
||||||
|
|
||||||
|
// fp8_scaled_mm
|
||||||
|
m.def(
|
||||||
|
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||||
|
"bias) -> Tensor");
|
||||||
|
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
|
||||||
|
|
||||||
// lightning_attention_decode
|
// lightning_attention_decode
|
||||||
m.def(
|
m.def(
|
||||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||||
|
|||||||
67
sgl-kernel/tests/test_fp8_gemm.py
Normal file
67
sgl-kernel/tests/test_fp8_gemm.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import fp8_scaled_mm
|
||||||
|
|
||||||
|
|
||||||
|
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    """Reference scaled matmul built from plain PyTorch ops.

    Both operands are converted to float32 and multiplied. The product is
    scaled by ``scale_a`` viewed as a column and then by ``scale_b`` viewed
    as a row, and the result is cast to ``out_dtype``. When ``bias`` is not
    ``None`` it is added, viewed as a row, after that cast.
    """
    product = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float32)
    # Keep the multiplication order (scale_a first, then scale_b) so rounding matches.
    scaled = (product * scale_a.view(-1, 1)) * scale_b.view(1, -1)
    result = scaled.to(out_dtype)

    if bias is None:
        return result
    return result + bias.view(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFp8Gemm(unittest.TestCase):
    """Accuracy checks for ``fp8_scaled_mm`` against the ``torch_scaled_mm`` reference."""

    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
        """Compare ``fp8_scaled_mm`` with the reference for one configuration.

        Builds a random (M, K) operand and a random (N, K) operand, both
        clamped to the float8_e4m3fn range and cast to that dtype, plus random
        float32 scales of length M and N and an optional bias of length N in
        ``out_dtype``. The second operand is transposed before both
        implementations are called. Results must agree within rtol=0.02 and
        atol=1.
        """
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        # Uniform values spanning (-fp8_max, fp8_max), then clamped and cast to FP8.
        a_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        b_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
        if with_bias:
            bias = torch.randn((N,), device=device, dtype=out_dtype)
        else:
            bias = None

        # Transposing the (N, K) tensor yields a (K, N) view with unit stride
        # along dim 0. No output buffer is pre-allocated: fp8_scaled_mm returns
        # its own result tensor.
        b_fp8 = b_fp8.t()
        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        rtol = 0.02
        atol = 1
        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        """Run the accuracy check over every combination of shape, bias, and dtype."""
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [16, 128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        bias_opts = [True, False]
        out_dtypes = [torch.bfloat16, torch.float16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for with_bias in bias_opts:
                        for out_dtype in out_dtypes:
                            self._test_accuracy_once(
                                M, N, K, with_bias, out_dtype, "cuda"
                            )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user