[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||
|
||||
@@ -69,6 +71,21 @@ WEIGHT_SHAPES = {
|
||||
}
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
fp8_type_: torch.dtype = torch.float8_e4m3fn
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
is_static = True
|
||||
if scale is None:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
is_static = False
|
||||
sgl_per_tensor_quant_fp8(input, output, scale, is_static)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
@@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
|
||||
b = torch.ones((N, K), device="cuda") * 5.0
|
||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
|
||||
|
||||
if "vllm-fp8" in provider:
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif "sglang-fp8" in provider:
|
||||
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_scaled_mm(
|
||||
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
||||
|
||||
@@ -48,6 +48,7 @@ limitations under the License.
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
#include "math.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace cute;
|
||||
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const c10::optional<torch::Tensor>& bias) {
|
||||
using CTAShape = Shape<_256, _128, _64>;
|
||||
using ClusterShape = Shape<_2, _2, _1>;
|
||||
using CTAShapeDefault = Shape<_256, _128, _64>;
|
||||
using ClusterShapeDefault = Shape<_2, _2, _1>;
|
||||
|
||||
using CTAShape256 = Shape<_128, _128, _128>;
|
||||
using ClusterShape256 = Shape<_2, _1, _1>;
|
||||
|
||||
using CTAShape64 = Shape<_64, _64, _128>;
|
||||
using ClusterShape64 = Shape<_1, _1, _1>;
|
||||
|
||||
using CTAShape16 = Shape<_64, _64, _128>;
|
||||
using ClusterShape16 = Shape<_1, _4, _1>;
|
||||
|
||||
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileSchedulerType = void;
|
||||
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
|
||||
using ElementOutput = OutType;
|
||||
using AccumElementType = float;
|
||||
|
||||
// Gemm type with bias
|
||||
using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShapeDefault,
|
||||
ClusterShapeDefault,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape256,
|
||||
ClusterShape256,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape64,
|
||||
ClusterShape64,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape16,
|
||||
ClusterShape16,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
|
||||
// Gemm type without bias
|
||||
using GemmDefault = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShapeDefault,
|
||||
ClusterShapeDefault,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
using Gemm256 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape256,
|
||||
ClusterShape256,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
using Gemm64 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape64,
|
||||
ClusterShape64,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
using Gemm16 = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape16,
|
||||
ClusterShape16,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
|
||||
// next power of 2 (minimum 16)
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (bias) {
|
||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
true>;
|
||||
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (16, 64]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (64, 256]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
// m in (256, inf]
|
||||
return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
} else {
|
||||
using Gemm = DeviceGemmFp8RowwiseSm100<
|
||||
ElementInput,
|
||||
ElementOutput,
|
||||
AccumElementType,
|
||||
CTAShape,
|
||||
ClusterShape,
|
||||
MainloopScheduleType,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerType,
|
||||
false>;
|
||||
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (16, 64]
|
||||
return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (64, 256]
|
||||
return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
|
||||
} else {
|
||||
return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
28
sgl-kernel/csrc/gemm/math.hpp
Normal file
28
sgl-kernel/csrc/gemm/math.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
static inline constexpr auto div_ceil(A a, B b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
// Round a down to the next multiple of b. The caller is responsible for making
|
||||
// sure that b is non-zero
|
||||
template <typename T>
|
||||
inline constexpr T round_to_previous_multiple_of(T a, T b) {
|
||||
return a % b == 0 ? a : (a / b) * b;
|
||||
}
|
||||
|
||||
// Round a up to the next multiple of b. The caller is responsible for making
|
||||
// sure that b is non-zero
|
||||
template <typename T>
|
||||
inline constexpr T round_to_next_multiple_of(T a, T b) {
|
||||
return a % b == 0 ? a : ((a / b) + 1) * b;
|
||||
}
|
||||
Reference in New Issue
Block a user