[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
||||||
|
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||||
|
|
||||||
@@ -69,6 +71,21 @@ WEIGHT_SHAPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to ``torch.float8_e4m3fn`` via the sgl-kernel op.

    Args:
        input: Tensor to quantize. The output buffer is allocated with the
            same shape and on the same device.
        scale: Optional pre-computed scale tensor. When given, it is passed
            to the kernel with the static flag set. When ``None``, a
            one-element float32 zero tensor is allocated on ``input``'s
            device and the kernel is called with the static flag cleared.
            NOTE(review): presumably the kernel then computes the scale and
            writes it into that buffer -- confirm against
            ``sgl_per_tensor_quant_fp8``.

    Returns:
        ``(quantized, scale)``: the FP8 output tensor and the scale tensor
        that was handed to the kernel.
    """
    quantized = torch.empty_like(
        input, device=input.device, dtype=torch.float8_e4m3fn
    )

    # A caller-supplied scale means "static" quantization; otherwise hand the
    # kernel a zero-initialized scalar buffer.
    use_static_scale = scale is not None
    if not use_static_scale:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)

    sgl_per_tensor_quant_fp8(input, quantized, scale, use_static_scale)
    return quantized, scale
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["batch_size"],
|
x_names=["batch_size"],
|
||||||
@@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
|
|||||||
b = torch.ones((N, K), device="cuda") * 5.0
|
b = torch.ones((N, K), device="cuda") * 5.0
|
||||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
||||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
|
||||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
|
||||||
b_fp8 = b_fp8.t()
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
|
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
|
||||||
|
|
||||||
if "vllm-fp8" in provider:
|
if "vllm-fp8" in provider:
|
||||||
|
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||||
|
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||||
|
b_fp8 = b_fp8.t()
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
elif "sglang-fp8" in provider:
|
elif "sglang-fp8" in provider:
|
||||||
|
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
|
||||||
|
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
|
||||||
|
b_fp8 = b_fp8.t()
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: sgl_scaled_mm(
|
lambda: sgl_scaled_mm(
|
||||||
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ limitations under the License.
|
|||||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||||
#include <cutlass/util/packed_stride.hpp>
|
#include <cutlass/util/packed_stride.hpp>
|
||||||
|
|
||||||
|
#include "math.hpp"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
|
|||||||
const torch::Tensor& scales_a,
|
const torch::Tensor& scales_a,
|
||||||
const torch::Tensor& scales_b,
|
const torch::Tensor& scales_b,
|
||||||
const c10::optional<torch::Tensor>& bias) {
|
const c10::optional<torch::Tensor>& bias) {
|
||||||
using CTAShape = Shape<_256, _128, _64>;
|
using CTAShapeDefault = Shape<_256, _128, _64>;
|
||||||
using ClusterShape = Shape<_2, _2, _1>;
|
using ClusterShapeDefault = Shape<_2, _2, _1>;
|
||||||
|
|
||||||
|
using CTAShape256 = Shape<_128, _128, _128>;
|
||||||
|
using ClusterShape256 = Shape<_2, _1, _1>;
|
||||||
|
|
||||||
|
using CTAShape64 = Shape<_64, _64, _128>;
|
||||||
|
using ClusterShape64 = Shape<_1, _1, _1>;
|
||||||
|
|
||||||
|
using CTAShape16 = Shape<_64, _64, _128>;
|
||||||
|
using ClusterShape16 = Shape<_1, _4, _1>;
|
||||||
|
|
||||||
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
|
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
using TileSchedulerType = void;
|
using TileSchedulerType = void;
|
||||||
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
|
|||||||
using ElementOutput = OutType;
|
using ElementOutput = OutType;
|
||||||
using AccumElementType = float;
|
using AccumElementType = float;
|
||||||
|
|
||||||
|
// Gemm type with bias
|
||||||
|
using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShapeDefault,
|
||||||
|
ClusterShapeDefault,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
true>;
|
||||||
|
using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShape256,
|
||||||
|
ClusterShape256,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
true>;
|
||||||
|
using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShape64,
|
||||||
|
ClusterShape64,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
true>;
|
||||||
|
using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShape16,
|
||||||
|
ClusterShape16,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
// Gemm type without bias
|
||||||
|
using GemmDefault = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShapeDefault,
|
||||||
|
ClusterShapeDefault,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
false>;
|
||||||
|
using Gemm256 = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShape256,
|
||||||
|
ClusterShape256,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
false>;
|
||||||
|
using Gemm64 = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShape64,
|
||||||
|
ClusterShape64,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
false>;
|
||||||
|
using Gemm16 = DeviceGemmFp8RowwiseSm100<
|
||||||
|
ElementInput,
|
||||||
|
ElementOutput,
|
||||||
|
AccumElementType,
|
||||||
|
CTAShape16,
|
||||||
|
ClusterShape16,
|
||||||
|
MainloopScheduleType,
|
||||||
|
EpilogueScheduleType,
|
||||||
|
TileSchedulerType,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
// next power of 2 (minimum 16)
|
||||||
|
uint32_t const m = a.size(0);
|
||||||
|
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||||
|
|
||||||
if (bias) {
|
if (bias) {
|
||||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
if (mp2 <= 16) {
|
||||||
ElementInput,
|
// m in [1, 16]
|
||||||
ElementOutput,
|
return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
|
||||||
AccumElementType,
|
} else if (mp2 <= 64) {
|
||||||
CTAShape,
|
// m in (16, 64]
|
||||||
ClusterShape,
|
return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
|
||||||
MainloopScheduleType,
|
} else if (mp2 <= 256) {
|
||||||
EpilogueScheduleType,
|
// m in (64, 256]
|
||||||
TileSchedulerType,
|
return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
|
||||||
true>;
|
} else {
|
||||||
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
// m in (256, inf]
|
||||||
|
return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
if (mp2 <= 16) {
|
||||||
ElementInput,
|
// m in [1, 16]
|
||||||
ElementOutput,
|
return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
|
||||||
AccumElementType,
|
} else if (mp2 <= 64) {
|
||||||
CTAShape,
|
// m in (16, 64]
|
||||||
ClusterShape,
|
return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
|
||||||
MainloopScheduleType,
|
} else if (mp2 <= 256) {
|
||||||
EpilogueScheduleType,
|
// m in (64, 256]
|
||||||
TileSchedulerType,
|
return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
|
||||||
false>;
|
} else {
|
||||||
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
28
sgl-kernel/csrc/gemm/math.hpp
Normal file
28
sgl-kernel/csrc/gemm/math.hpp
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <climits>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
// Returns the smallest power of two that is >= num.
//
// Special cases:
//   * num == 0 or num == 1 is returned unchanged (so 0 maps to 0, which is
//     not itself a power of two).
//   * num > 2^31 has no representable answer in uint32_t; 0 is returned
//     instead of performing an out-of-range shift.
inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;

  constexpr uint32_t kBits = CHAR_BIT * sizeof(uint32_t);
  // Bits needed to represent (num - 1); num - 1 is non-zero here, so the
  // count-leading-zeros builtin is well defined.
  uint32_t const shift = kBits - static_cast<uint32_t>(__builtin_clz(num - 1));

  // Shift an unsigned 1 so that a shift of 31 cannot overflow a signed int,
  // and never shift by the full width of the type (undefined behaviour).
  return shift >= kBits ? uint32_t{0} : (uint32_t{1} << shift);
}
|
||||||
|
|
||||||
|
// Integer division of a by b, rounded up. Computed as (a + b - 1) / b, so
// the usual caveats of that formula apply: b must be non-zero and a + b - 1
// must not overflow the promoted type.
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
  // Bias the numerator so that truncating division rounds up.
  auto const biased_numerator = a + b - 1;
  return biased_numerator / b;
}
|
||||||
|
|
||||||
|
// Round a down to the nearest multiple of b that does not exceed it (a is
// returned unchanged when it is already a multiple). The caller is
// responsible for making sure that b is non-zero.
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
  if (a % b == 0) {
    return a;
  }
  // Truncating division drops the remainder; scaling back up yields the
  // multiple of b just before a.
  return (a / b) * b;
}
|
||||||
|
|
||||||
|
// Round a up to the nearest multiple of b that is not below it (a is
// returned unchanged when it is already a multiple). The caller is
// responsible for making sure that b is non-zero.
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
  if (a % b == 0) {
    return a;
  }
  // Step one whole multiple past the truncated quotient.
  return ((a / b) + 1) * b;
}
|
||||||
Reference in New Issue
Block a user