diff --git a/python/sglang/srt/bench_utils.py b/python/sglang/srt/bench_utils.py index ea400bfa8..e9f7fcbb4 100644 --- a/python/sglang/srt/bench_utils.py +++ b/python/sglang/srt/bench_utils.py @@ -1,5 +1,4 @@ import os -import re import sys from contextlib import nullcontext @@ -109,8 +108,7 @@ def bench_kineto( if not with_multiple_kernels: for name in kernel_names: assert ( - sum([int(re.search(name, line) is not None) for line in prof_lines]) - == 1 + sum([name in line for line in prof_lines]) == 1 ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})" # Save chrome traces @@ -124,7 +122,7 @@ def bench_kineto( total_time = 0 total_num = 0 for line in prof_lines: - if re.search(name, line) is not None: + if name in line: time_str = line.split()[-2] num_str = line.split()[-1] for unit, scale in units.items(): diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 9c30dc060..f0512365b 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -43,17 +43,11 @@ _is_cpu = is_cpu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: - from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8 - - # Temporary - try: - from sgl_kernel import sgl_per_token_group_quant_8bit - - enable_sgl_per_token_group_quant_8bit = True - except ImportError: - from sgl_kernel import sgl_per_token_group_quant_fp8 - - enable_sgl_per_token_group_quant_8bit = False + from sgl_kernel import ( + sgl_per_tensor_quant_fp8, + sgl_per_token_group_quant_fp8, + sgl_per_token_quant_fp8, + ) if _is_hip: if _use_aiter: @@ -502,24 +496,9 @@ def sglang_per_token_group_quant_fp8( ) if x.shape[0] > 0: - # Temporary - if enable_sgl_per_token_group_quant_8bit: - sgl_per_token_group_quant_8bit( - x, - x_q, - x_s, - group_size, - eps, - fp8_min, - fp8_max, - scale_ue8m0, - fuse_silu_and_mul, - masked_m, - ) - else: - sgl_per_token_group_quant_fp8( - x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 - ) + sgl_per_token_group_quant_fp8( + x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 + ) return x_q, x_s diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index 826d16e3c..7c6c3dbd4 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -12,13 +12,7 @@ from sglang.srt.utils import get_device_name, is_cuda _is_cuda = is_cuda() if _is_cuda: - # Temporary - try: - from sgl_kernel import sgl_per_token_group_quant_8bit - except ImportError: - from sgl_kernel import ( - sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit, - ) + from sgl_kernel import sgl_per_token_group_quant_int8 logger = logging.getLogger(__name__) @@ -210,7 +204,7 @@ def sglang_per_token_group_quant_int8( dtype=torch.float32, ) - sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max) + sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) return x_q, x_s diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 7237312ce..3f37a3248 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -1,12 +1,10 @@ import itertools -import os import time from functools import partial from pathlib import Path 
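+# (Review note, not part of the original patch: after this change, bench_kineto matches
+# kernel names by plain substring against the profiler table rather than by regex, so the
+# entries passed via `kernel_names` in this benchmark must appear verbatim inside the
+# profiled kernel name.)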
import torch import triton -from sgl_kernel.test_utils import create_per_token_group_quant_test_data from sglang.srt.bench_utils import bench_kineto from sglang.srt.layers.quantization.fp8_kernel import ( @@ -21,231 +19,78 @@ from sglang.srt.utils import is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated" -if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")): - # configs = [[ - # 768, - # 16384, - # 128, - # None, - # fp8_type_, - # dict( - # column_major_scales=True, - # scale_tma_aligned=True, - # scale_ue8m0=True, - # fuse_silu_and_mul=False, - # masked_layout_mode=None, - # ), - # ]] - configs = [ - [ - 768 * 8, - 2048, - 128, - 48, - fp8_type_, - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - # masked_layout_mode=None, - masked_layout_mode="balanced", - # masked_layout_mode="extreme", - ), - ] - ] -elif mode_concentrated: - configs = list( - itertools.product( - [768], - [1536, 7168, 16384], - [128], - [None], - [fp8_type_], - [ - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - ], - ) - ) + list( - itertools.product( - [768 * 8], - [2048], - [128], - [48], - [fp8_type_], - [ - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="balanced", - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="imbalanced", - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="extreme", - ), - ], - ) - ) -else: - configs = list( - itertools.product( - [1, 4, 16, 64, 256, 768, 2048, 8192, 16384], - [1536, 7168, 16384], - [128], - [None], - [fp8_type_], - [ - dict( - column_major_scales=False, - scale_tma_aligned=False, - scale_ue8m0=False, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=False, - scale_ue8m0=False, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=False, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - ], - ) - ) + list( - itertools.product( - [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8], - [2048], - [128], - [8, 16, 32, 48], - [fp8_type_], - [ - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="balanced", - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="imbalanced", - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="extreme", - ), - ], - ) +num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384] +hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1 
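+# (Review note, illustrative and not part of the original patch: each benchmark config is
+# the tuple (num_tokens, hidden_dim, group_size, dst_dtype, flags), matching the x_names
+# order below. The last flag set enables scale_ue8m0, i.e. the per-group scale
+# amax / 448 (fp8 e4m3 max) is rounded up to a power of two and stored as a biased
+# exponent; e.g. amax/448 = 7e-4 -> 2**ceil(log2(7e-4)) = 2**-10, stored as uint8
+# (-10 + 127) = 117.)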
+group_size_range = [128] # For DeepSeek V3/R1 +# TODO test int8 +dst_dtype_range = [fp8_type_] +flags_range = [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + ), +] + + +configs = list( + itertools.product( + num_tokens_range, + hidden_dim_range, + group_size_range, + dst_dtype_range, + flags_range, ) +) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=[ - "num_tokens", - "hidden_dim", - "group_size", - "num_ranks", - "dst_dtype", - "flags", - ], + x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"], x_vals=configs, line_arg="provider", line_vals=["triton", "sglang"], - # Triton has multi kernels and we only report the time for the core one - line_names=["Triton (Inaccurate)", "SGL Kernel"], + line_names=["Triton", "SGL Kernel"], styles=[("blue", "-"), ("green", "-")], ylabel="us", plot_name="per-token-group-quant-8bit-performance", args={}, ) ) -def benchmark( - num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider -): - print( - f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}" - ) +def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider): + if flags["scale_ue8m0"] and group_size != 128: + return - x, masked_m = create_per_token_group_quant_test_data( - num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags - ) + device = torch.device("cuda") + + x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) fn, kernel_names = { - "triton": ( - triton_per_token_group_quant_8bit, - "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel", - ), + "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"), "sglang": ( sglang_per_token_group_quant_8bit, "per_token_group_quant_8bit_kernel", ), }[provider] - bench_fn = lambda: fn( - x=x, - masked_m=masked_m, - group_size=group_size, - dst_dtype=dst_dtype, - **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]}, - ) + bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags) - time_s = bench_kineto( - bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30 - ) + time_s = bench_kineto(bench_fn, kernel_names=kernel_names) return time_s * 1e6 diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 599bcf591..4f95c9138 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -121,9 +121,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm); m.def( - "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size," - " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? 
masked_m) -> ()"); - m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit); + "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()"); + m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); + + m.def( + "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps, float int8_min, float int8_max) -> ()"); + m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8); m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8); diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index 4c1d96a6a..82daaef19 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -1,396 +1,119 @@ #include -#include +#include #include #include #include "utils.h" -template __device__ __forceinline__ float GroupReduceMax(float val, const int tid) { unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; - static_assert( - (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1, - "THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16"); - - if constexpr (THREADS_PER_SUBWARP >= 16) { - val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); - } - if constexpr (THREADS_PER_SUBWARP >= 8) { - val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); - } - if constexpr (THREADS_PER_SUBWARP >= 4) { - val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); - } - if constexpr (THREADS_PER_SUBWARP >= 2) { - val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); - } + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); return val; } -__device__ __forceinline__ float silu(const float& val) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - float half = 0.5f * val; - float t = __tanhf(half); - return half * (1.0f + t); -#else - return val / (1.0f + __expf(-val)); -#endif -} - -__device__ float2 fmul2_rn(float2 a, float2 b) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - return __fmul2_rn(a, b); -#else - float2 result; - result.x = a.x * b.x; - result.y = a.y * b.y; - return result; -#endif -} - -// Copied and modified from DeepEP -__forceinline__ __device__ float fast_pow2(int x) { - // We can ensure `-126 <= x and x <= 127` - uint32_t bits_x = (x + 127) << 23; - return *reinterpret_cast(&bits_x); -} - -// Copied and modified from DeepEP -__forceinline__ __device__ int fast_log2_ceil(float x) { - auto bits_x = *reinterpret_cast(&x); - auto exp_x = (bits_x >> 23) & 0xff; - auto man_bits = bits_x & ((1 << 23) - 1); - return exp_x - 127 + (man_bits != 0); -} - -// Copied and modified from DeepEP -template -__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) { - constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX; - if constexpr (ROUND_SCALE) { - auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV); - scale = fast_pow2(-exp_scale_inv); - scale_inv = fast_pow2(exp_scale_inv); - } else { - scale_inv = amax * MAX_8BIT_INV; - scale = dtype_info::MAX / amax; - } -} - -// Copied and 
modified from DeepEP -template > -__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) { - if constexpr (SCALE_UE8M0) { - return static_cast((*reinterpret_cast(&value)) >> 23); - } else { - return value; - } -} - -__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) { - asm volatile( - "st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); -} - -__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) { - int4 ret; - asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];" - : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) - : "l"(ptr)); - return ret; -} - -template -struct DtypeInfo; - -template <> -struct DtypeInfo { - static constexpr float MIN = -128; - static constexpr float MAX = 127; -}; - -template <> -struct DtypeInfo { - static constexpr float MIN = -448; - static constexpr float MAX = 448; -}; - -template -__device__ __forceinline__ int compute_input_group_start_offset( - int expert_idx, - int token_idx, - int hidden_dim_group_idx, - int hidden_size, - int num_tokens_per_expert, - int group_size) { - return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + - token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size; -} - -constexpr float LOCAL_ABSMAX_ABS = 1e-10; -constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32; - -struct NaiveScheduler { - static void compute_exec_config( - int threads_per_subwarp, - int num_local_experts, - int hidden_dim_num_groups, - int num_groups, - int& subwarps_per_block, - dim3& grid, - dim3& block) { - subwarps_per_block = ([=]() -> int { - if (num_groups % 16 == 0) { - return 16; - } else if (num_groups % 8 == 0) { - return 8; - } else if (num_groups % 4 == 0) { - return 4; - } else if (num_groups % 2 == 0) { - return 2; - } - return 1; - })(); - grid = dim3(num_groups / subwarps_per_block); - block = dim3(subwarps_per_block * threads_per_subwarp); - } - - template - __device__ __forceinline__ static void execute( - const int subwarps_per_block, - const int hidden_dim_num_groups, - const int32_t* masked_m, - const int num_tokens_per_expert, - FUNC fn) { - constexpr int expert_idx = 0; - - const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP; - const int lane_id = threadIdx.x % THREADS_PER_SUBWARP; - - const int64_t block_group_id = blockIdx.x * subwarps_per_block; - const int64_t group_id = block_group_id + subwarp_id; - - int64_t input_group_start_offset; - if constexpr (!FUSE_SILU_AND_MUL) { - input_group_start_offset = group_id * GROUP_SIZE; - } - - const int token_idx = group_id / hidden_dim_num_groups; - // At the hidden_size dimension, we are handling idx-th group - const int hidden_dim_group_idx = group_id % hidden_dim_num_groups; - - if constexpr (FUSE_SILU_AND_MUL) { - const int hidden_size = hidden_dim_num_groups * GROUP_SIZE; - input_group_start_offset = compute_input_group_start_offset( - expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE); - } - - fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset); - } -}; - -struct MaskedLayoutScheduler { - // TODO can be dynamically determined (which may be good when num rank is small) - static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024; - static constexpr int SUBWARPS_PER_BLOCK = 16; - - static void compute_exec_config( - int threads_per_subwarp, - int num_local_experts, - int hidden_dim_num_groups, - int num_groups, - int& 
subwarps_per_block, - dim3& grid, - dim3& block) { - subwarps_per_block = SUBWARPS_PER_BLOCK; - TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0); - grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts); - block = dim3(subwarps_per_block * threads_per_subwarp); - } - - template - __device__ __forceinline__ static void execute( - const int subwarps_per_block, - const int hidden_dim_num_groups, - const int32_t* masked_m, - const int num_tokens_per_expert, - FUNC fn) { - const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP; - const int lane_id = threadIdx.x % THREADS_PER_SUBWARP; - - const int expert_idx = blockIdx.z; - const int token_idx_start = blockIdx.y; - - const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id; - - const int curr_expert_token_num = masked_m[expert_idx]; - - for (int token_idx = token_idx_start; token_idx < curr_expert_token_num; - token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) { - const int hidden_size = hidden_dim_num_groups * GROUP_SIZE; - const int64_t input_group_start_offset = compute_input_group_start_offset( - expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE); - fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset); - } - } -}; - template < - typename SCHEDULER, - int GROUP_SIZE, - int THREADS_PER_SUBWARP, typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false, bool SCALE_UE8M0 = false, - bool FUSE_SILU_AND_MUL = false, typename scale_packed_t = std::conditional_t> __global__ void per_token_group_quant_8bit_kernel( const T* __restrict__ input, - DST_DTYPE* __restrict__ output_q, + void* __restrict__ output_q, scale_packed_t* __restrict__ output_s, - const int32_t* __restrict__ masked_m, - const int subwarps_per_block, - const int hidden_dim_num_groups, - // TODO can this be removed? 
- const int scale_expert_stride, - const int scale_hidden_stride, - const int num_tokens_per_expert) { - using dst_dtype_info = DtypeInfo; + const int group_size, + const int num_groups, + const int groups_per_block, + const float eps, + const float min_8bit, + const float max_8bit, + const int num_groups_per_row = 0, + const int scale_stride = 0) { + const int threads_per_group = 16; + const int64_t local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; + + float local_absmax = eps; + using scale_element_t = std::conditional_t; static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); - SCHEDULER::execute( - subwarps_per_block, - hidden_dim_num_groups, - masked_m, - num_tokens_per_expert, - [&](const int expert_idx, - const int token_idx, - const int hidden_dim_group_idx, - const int lane_id, - const int input_group_start_offset) { - constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T); - constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4); + const T* group_input = input + block_group_offset; + DST_DTYPE* group_output = static_cast(output_q) + block_group_offset; + scale_element_t* scale_output; - const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups + - token_idx * hidden_dim_num_groups + hidden_dim_group_idx; + if constexpr (IS_COLUMN_MAJOR) { + const int num_elems_per_pack = static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); + const int row_idx = global_group_id / num_groups_per_row; + const int col_idx_unpacked = global_group_id % num_groups_per_row; + const int col_idx = col_idx_unpacked / num_elems_per_pack; + const int pack_idx = col_idx_unpacked % num_elems_per_pack; + scale_output = reinterpret_cast(output_s) + + (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx); + } else { + static_assert(!SCALE_UE8M0); + scale_output = output_s + global_group_id; + } - int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE]; - T* input_primary_vec = reinterpret_cast(input_primary_int4); - static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4)); + constexpr uint32_t vec_size = 16 / sizeof(T); + using vec_t = flashinfer::vec_t; - int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE]; - T* input_secondary_vec = reinterpret_cast(input_secondary_int4); - static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4)); + const int32_t num_vec_elems = group_size / vec_size; + + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + vec_t input_vec; + input_vec.cast_load(group_input + i * vec_size); #pragma unroll - for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) { - input_primary_int4[j] = ld_global_nc( - reinterpret_cast(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j); - } - if constexpr (FUSE_SILU_AND_MUL) { - const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE; -#pragma unroll - for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) { - input_secondary_int4[j] = ld_global_nc( - reinterpret_cast( - input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) + - j); - } - } + for (uint32_t j = 0; j < vec_size; ++j) { + float val = 
static_cast(input_vec[j]); + float abs_val = fabsf(val); + local_absmax = fmaxf(local_absmax, abs_val); + } + } - constexpr int num_elems_per_pack = static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); - scale_element_t* scale_output; - if constexpr (IS_COLUMN_MAJOR) { - constexpr int scale_token_stride = 1; + local_absmax = GroupReduceMax(local_absmax, lane_id); - const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack; - const int pack_idx = hidden_dim_group_idx % num_elems_per_pack; - scale_output = reinterpret_cast(output_s) + - (expert_idx * scale_expert_stride * num_elems_per_pack + - hidden_idx_packed * scale_hidden_stride * num_elems_per_pack + - token_idx * scale_token_stride * num_elems_per_pack + pack_idx); - } else { - static_assert(!SCALE_UE8M0); - scale_output = output_s + offset_num_groups; - } + float y_s = local_absmax / max_8bit; + if constexpr (SCALE_UE8M0) { + y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f)))); + } - // can speed up if too slow - if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) { - const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack; - if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and - (lane_id < num_elems_per_pack - remainder_num_groups)) { - const int shift = 1 + lane_id; - *(scale_output + shift) = 0; - } - } + // TODO can optimize + scale_element_t y_s_quant; + if constexpr (SCALE_UE8M0) { + y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127); + } else { + y_s_quant = y_s; + } - float local_absmax = LOCAL_ABSMAX_ABS; + if (lane_id == 0) { + *scale_output = y_s_quant; + } + + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + vec_t input_vec; + input_vec.cast_load(group_input + i * vec_size); #pragma unroll - for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) { - float val; - if constexpr (FUSE_SILU_AND_MUL) { - // TODO maybe vectorize - T val_lowprec = static_cast(silu(static_cast(input_primary_vec[j]))) * input_secondary_vec[j]; - val = static_cast(val_lowprec); - input_primary_vec[j] = val_lowprec; - } else { - val = static_cast(input_primary_vec[j]); - } - - float abs_val = fabsf(val); - local_absmax = fmaxf(local_absmax, abs_val); - } - - local_absmax = GroupReduceMax(local_absmax, lane_id); - - float y_scale, y_scale_inv; - calculate_fp8_scales(local_absmax, y_scale, y_scale_inv); - float2 y_scale_repeated = {y_scale, y_scale}; - - if (lane_id == 0) { - *scale_output = extract_required_scale_format(y_scale_inv); - } - - int4 output_buf; - static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE)); - - if constexpr (std::is_same_v) { - const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf); - static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t)); - static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0); - -#pragma unroll - for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) { - float2 inputx2 = {static_cast(input_primary_vec[j]), static_cast(input_primary_vec[j + 1])}; - float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated); - output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3); - } - } else { - const auto output_buf_ptr = reinterpret_cast(&output_buf); - -#pragma unroll - for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) { - float val = static_cast(input_primary_vec[j]); - float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX); - output_buf_ptr[j] = DST_DTYPE(q_val); - } - } - - st_global( - 
reinterpret_cast(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE), - output_buf); - }); + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); + float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit); + group_output[i * vec_size + j] = DST_DTYPE(q_val); + } + } } void sgl_per_token_group_quant_8bit( - // vanilla: (num_tokens, hidden_size) - // fuse_silu_and_mul: (num_tokens, hidden_size * 2) - // fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2) torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, @@ -398,113 +121,120 @@ void sgl_per_token_group_quant_8bit( double eps, double min_8bit, double max_8bit, - bool scale_ue8m0, - bool fuse_silu_and_mul, - const std::optional& masked_m) { + bool scale_ue8m0 = false) { CHECK_INPUT(input); CHECK_INPUT(output_q); - TORCH_CHECK(input.numel() > 0); - TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13); + const int num_groups = input.numel() / group_size; CHECK_EQ(input.numel() % group_size, 0); - const int num_groups = static_cast(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1); - - const bool masked_layout = masked_m.has_value(); - TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2)); - - const int num_local_experts = masked_layout ? input.size(0) : 1; + CHECK_EQ(output_s.dim(), 2); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + constexpr int THREADS_PER_GROUP = 16; + + int groups_per_block = 1; + + if (num_groups % 16 == 0) { + groups_per_block = 16; + } else if (num_groups % 8 == 0) { + groups_per_block = 8; + } else if (num_groups % 4 == 0) { + groups_per_block = 4; + } else if (num_groups % 2 == 0) { + groups_per_block = 2; + } + auto dst_type = output_q.scalar_type(); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; - const bool is_column_major = output_s.stride(-2) < output_s.stride(-1); - const int hidden_dim_num_groups = static_cast(output_q.size(-1)) / group_size; - const int num_tokens_per_expert = static_cast(output_q.size(-2)); - const int scale_expert_stride = masked_layout ? static_cast(output_s.stride(0)) : 0; - const int scale_hidden_stride = static_cast(output_s.stride(-1)); + const bool is_column_major = output_s.stride(0) < output_s.stride(1); + const int hidden_dim = input.size(input.dim() - 1); + const int num_groups_per_row = hidden_dim / group_size; + const int scale_stride = output_s.stride(1); -#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \ - do { \ - int subwarps_per_block; \ - dim3 grid, block; \ - SCHEDULER::compute_exec_config( \ - THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \ - \ - per_token_group_quant_8bit_kernel \ - <<>>( \ - static_cast(input.data_ptr()), \ - static_cast(output_q.data_ptr()), \ - static_cast(output_s.data_ptr()), \ - static_cast(masked_m.has_value() ? 
masked_m->data_ptr() : 0), \ - subwarps_per_block, \ - hidden_dim_num_groups, \ - scale_expert_stride, \ - scale_hidden_stride, \ - num_tokens_per_expert); \ +#define LAUNCH_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + if (is_column_major) { \ + if (scale_ue8m0) { \ + per_token_group_quant_8bit_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + groups_per_block, \ + (float)eps, \ + (float)min_8bit, \ + (float)max_8bit, \ + num_groups_per_row, \ + scale_stride); \ + } else { \ + per_token_group_quant_8bit_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + groups_per_block, \ + (float)eps, \ + (float)min_8bit, \ + (float)max_8bit, \ + num_groups_per_row, \ + scale_stride); \ + } \ + } else { \ + assert(!scale_ue8m0); \ + per_token_group_quant_8bit_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + groups_per_block, \ + (float)eps, \ + (float)min_8bit, \ + (float)max_8bit); \ + } \ } while (0) -#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE) \ - do { \ - constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16; \ - TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T)); \ - \ - using dst_dtype_info = DtypeInfo; \ - CHECK_EQ(dst_dtype_info::MIN, min_8bit); \ - CHECK_EQ(dst_dtype_info::MAX, max_8bit); \ - \ - if (is_column_major) { \ - if (scale_ue8m0) { \ - if (fuse_silu_and_mul) { \ - if (masked_layout) { \ - LAUNCH_KERNEL_INNER( \ - MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \ - } else { \ - LAUNCH_KERNEL_INNER( \ - NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \ - } \ - } else { \ - LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \ - } \ - } else { \ - LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true); \ - } \ - } else { \ - LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false); \ - } \ - } while (0) - -#define LAUNCH_KERNEL_OUTER(...) 
\ - switch (group_size) { \ - case 16: \ - LAUNCH_KERNEL(16, __VA_ARGS__); \ - break; \ - case 32: \ - LAUNCH_KERNEL(32, __VA_ARGS__); \ - break; \ - case 64: \ - LAUNCH_KERNEL(64, __VA_ARGS__); \ - break; \ - case 128: \ - LAUNCH_KERNEL(128, __VA_ARGS__); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported group_size"); \ - } \ - while (0) - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { if (dst_type == at::ScalarType::Char) { - LAUNCH_KERNEL_OUTER(scalar_t, int8_t); + LAUNCH_KERNEL(scalar_t, int8_t); return true; } else if (dst_type == at::ScalarType::Float8_e4m3fn) { - LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn); + LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3); return true; } return false; }); #undef LAUNCH_KERNEL -#undef LAUNCH_KERNEL_INNER +} + +void sgl_per_token_group_quant_int8( + torch::Tensor input, + torch::Tensor output_q, + torch::Tensor output_s, + int64_t group_size, + double eps, + double int8_min, + double int8_max) { + sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max); +} + +void sgl_per_token_group_quant_fp8( + torch::Tensor input, + torch::Tensor output_q, + torch::Tensor output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max, + bool scale_ue8m0) { + sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 1cd85c911..a13af546a 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -207,17 +207,23 @@ torch::Tensor fp8_blockwise_scaled_mm( const torch::Dtype& out_dtype); void scaled_fp4_quant( torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); -void sgl_per_token_group_quant_8bit( +void sgl_per_token_group_quant_fp8( at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps, - double min_8bit, - double max_8bit, - bool scale_ue8m0, - bool fuse_silu_and_mul, - const std::optional& masked_m); + double fp8_min, + double fp8_max, + bool scale_ue8m0); +void sgl_per_token_group_quant_int8( + at::Tensor input, + at::Tensor output_q, + at::Tensor output_s, + int64_t group_size, + double eps, + double int8_min, + double int8_max); void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static); void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s); void bmm_fp8( diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index f628af249..76c87d30b 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -58,7 +58,8 @@ from sgl_kernel.gemm import ( scaled_fp4_grouped_quant, scaled_fp4_quant, sgl_per_tensor_quant_fp8, - sgl_per_token_group_quant_8bit, + sgl_per_token_group_quant_fp8, + sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, shuffle_rows, silu_and_mul_scaled_fp4_grouped_quant, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 1a4c5d2d5..36672877d 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -98,7 +98,7 @@ def dsv3_fused_a_gemm( return output -def sgl_per_token_group_quant_8bit( +def sgl_per_token_group_quant_fp8( input: torch.Tensor, output_q: 
torch.Tensor, output_s: torch.Tensor, @@ -106,21 +106,24 @@ def sgl_per_token_group_quant_8bit( eps: float, fp8_min: float, fp8_max: float, - scale_ue8m0: bool = False, - fuse_silu_and_mul: bool = False, - masked_m: Optional[torch.Tensor] = None, + scale_ue8m0: bool, ) -> None: - torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default( - input, - output_q, - output_s, - group_size, - eps, - fp8_min, - fp8_max, - scale_ue8m0, - fuse_silu_and_mul, - masked_m, + torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default( + input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 + ) + + +def sgl_per_token_group_quant_int8( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + group_size: int, + eps: float, + int8_min: float, + int8_max: float, +) -> None: + torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default( + input, output_q, output_s, group_size, eps, int8_min, int8_max ) diff --git a/sgl-kernel/python/sgl_kernel/test_utils.py b/sgl-kernel/python/sgl_kernel/test_utils.py deleted file mode 100644 index ede113fd0..000000000 --- a/sgl-kernel/python/sgl_kernel/test_utils.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch - - -def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags): - device = torch.device("cuda") - dtype = torch.bfloat16 - - seed = num_tokens * 10000 + hidden_dim - gen_cpu = torch.Generator(device="cpu") - gen_cpu.manual_seed(seed) - gen_cuda = torch.Generator(device="cuda") - gen_cuda.manual_seed(seed) - - if flags["fuse_silu_and_mul"]: - effective_hidden_dim = hidden_dim * 2 - else: - effective_hidden_dim = hidden_dim - del hidden_dim - - if (masked_layout_mode := flags["masked_layout_mode"]) is not None: - num_max_dispatch_tokens_per_rank = 768 - num_global_experts = 288 - num_local_experts, remainder = divmod(num_global_experts, num_ranks) - assert remainder == 0 - - # mimic DeepEP low_latency_dispatch output - x = torch.randn( - num_local_experts, - num_max_dispatch_tokens_per_rank * num_ranks, - effective_hidden_dim, - device=device, - dtype=dtype, - generator=gen_cuda, - ) - - if masked_layout_mode == "balanced": - masked_m = _compute_balanced_split(num_tokens, num_local_experts) - elif masked_layout_mode == "imbalanced": - masked_m = _compute_imbalanced_split( - num_tokens, num_local_experts, gen_cpu=gen_cpu - ) - elif masked_layout_mode == "extreme": - masked_m = torch.tensor( - [num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int - ) - else: - raise NotImplementedError - print(f"{masked_layout_mode=} {masked_m=} {x.shape=}") - - masked_m = masked_m.to(device) - - return x, masked_m - else: - x = torch.randn( - num_tokens, - effective_hidden_dim, - device=device, - dtype=dtype, - generator=gen_cuda, - ) - x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10 - return x, None - - -def _compute_balanced_split(total: int, arr_len: int): - base = total // arr_len - remainder = total % arr_len - ans = [base + 1 if i < remainder else base for i in range(arr_len)] - assert sum(ans) == total - return torch.tensor(ans, dtype=torch.int) - - -def _compute_imbalanced_split( - total: int, arr_len: int, gen_cpu, dtype=torch.int -) -> list[int]: - # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is - noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3 - - noise = noise_raw / noise_raw.sum() - ans = (noise * total).round().to(dtype) - - diff = total - ans.sum().item() - while diff != 0: - idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item() - 
if diff > 0: - ans[idx] += 1 - diff -= 1 - elif diff < 0 and ans[idx] > 0: - ans[idx] -= 1 - diff += 1 - - assert sum(ans) == total - return ans - - -def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor): - assert (a.shape == b.shape) and ( - a.dtype == b.dtype - ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}" - numel = a.numel() - - if a.dtype == torch.float8_e4m3fn: - a_u8 = a.view(torch.uint8) - b_u8 = b.view(torch.uint8) - diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs() - - count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item() - count_tiny_diff = (diff_u8 == 1).sum().item() - count_large_diff = (diff_u8 >= 2).sum().item() - elif a.dtype == torch.int8: - diff = (a.to(torch.int16) - a.to(torch.int16)).abs() - count_diff_sign = ((a >= 0) & (b < 0)).sum().item() - count_tiny_diff = (diff == 1).sum().item() - count_large_diff = (diff >= 2).sum().item() - else: - raise NotImplementedError - - assert ( - (count_diff_sign == 0) - and (count_large_diff == 0) - and ( - (count_tiny_diff / numel < 0.005) - or ((count_tiny_diff / numel < 0.04) and (numel <= 4096)) - ) - ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}" diff --git a/sgl-kernel/tests/test_per_token_group_quant_8bit.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py index f47c78414..778d14d31 100644 --- a/sgl-kernel/tests/test_per_token_group_quant_8bit.py +++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py @@ -1,199 +1,96 @@ import itertools -import os -import time -from pathlib import Path import pytest import torch -from sgl_kernel.test_utils import ( - assert_all_close_or_tiny_diff, - create_per_token_group_quant_test_data, -) +from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_8bit as triton_per_token_group_quant_8bit, ) from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit -from sglang.srt.utils import get_bool_env_var, is_hip +from sglang.srt.layers.quantization.utils import assert_fp8_all_close +from sglang.srt.utils import is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -configs = list( - itertools.product( - [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192], # num_tokens - [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384], # hidden_dim - [16, 32, 64, 128], # group_size - [None], # num_ranks - [fp8_type_, torch.int8], # dtype - [ - dict( - column_major_scales=False, - scale_tma_aligned=False, - scale_ue8m0=False, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=False, - scale_ue8m0=False, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=False, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=False, - masked_layout_mode=None, - ), - ], - ) -) + list( - itertools.product( - [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8], - # TODO support more - [2048], - [128], - [8, 16, 32, 48], - [fp8_type_], - [ - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode=None, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="balanced", - ), - dict( - 
column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="imbalanced", - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - fuse_silu_and_mul=True, - masked_layout_mode="extreme", - ), - ], - ) -) - @pytest.mark.parametrize( - "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs + "num_tokens, hidden_dim, group_size, dst_dtype, flags", + list( + itertools.product( + [127, 128, 512, 1024, 4096, 8192], # num_tokens + [256, 512, 1024, 2048, 4096], # hidden_dim + [8, 16, 32, 64, 128], # group_size + # TODO test int8 + [fp8_type_], # dtype + [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + ), + ], + ) + ), ) def test_per_token_group_quant_with_column_major( num_tokens, hidden_dim, group_size, - num_ranks, dst_dtype, flags, ): - print( - f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}" - ) - - arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device()) - if flags["scale_ue8m0"] and (arch_major <= 9): - pytest.skip("Only Blackwell need ue8m0 fusion") - return - - if (flags["scale_ue8m0"] and (group_size != 128)) or ( - (dst_dtype == torch.int8) and flags["column_major_scales"] - ): + if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)): pytest.skip() return + if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL: + pytest.skip("scale_ue8m0 only supported on Blackwell") + return - x, masked_m = create_per_token_group_quant_test_data( - num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags - ) - - # print("hack data!!!") - # x = torch.full_like(x, fill_value=100) + x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) execute_kwargs = dict( x=x, - masked_m=masked_m, group_size=group_size, eps=1e-10, dst_dtype=dst_dtype, - **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]}, + **flags, ) - def _postprocess(x_q, x_s): - if masked_m is not None: - print(f"Mask tokens after {masked_m} to be zero") - for i in range(len(masked_m)): - x_q[i, masked_m[i] :, :] = 0 - x_s[i, masked_m[i] :, :] = 0 - return x_q, x_s + x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs) - x_q_triton, x_s_triton = _postprocess( - *triton_per_token_group_quant_8bit(**execute_kwargs) + # torch.set_printoptions(profile="full") + # print(f"{x_q_triton=}") + # print(f"{x_s_triton=}") + # print(f"{x_q_sglang=}") + # print(f"{x_s_sglang=}") + # torch.set_printoptions(profile="default") + + assert_fp8_all_close(x_q_triton, x_q_sglang) + torch.testing.assert_close( + x_s_triton.contiguous(), + x_s_sglang.contiguous(), + rtol=1e-3, + atol=1e-5, + msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}", ) - x_q_sglang, x_s_sglang = _postprocess( - *sglang_per_token_group_quant_8bit(**execute_kwargs) - ) - - try: - assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang) - torch.testing.assert_close( - x_s_triton.contiguous(), - x_s_sglang.contiguous(), - rtol=1e-3, - atol=1e-5, - msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}", - ) - except 
AssertionError: - # torch.set_printoptions(profile="full") - print( - f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}" - ) - print(f"{x=}") - print(f"{masked_m=}") - print(f"{x_q_triton=}") - print(f"{x_s_triton=}") - print(f"{x_q_sglang=}") - print(f"{x_s_sglang=}") - # torch.set_printoptions(profile="default") - - # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "": - # import matplotlib.pyplot as plt - # - # base_stem = time.time() - # for name, value in [ - # ("x_q", x_q_triton != x_q_sglang), - # ("x_s", x_s_triton != x_s_sglang), - # ]: - # value = value.reshape((-1, value.shape[-1])) - # plt.figure(figsize=(20, 20)) - # plt.imshow((value * 1.0).cpu().numpy()) - # p = Path(d) / f"{base_stem}_{name}.png" - # print(f"Write diff to {p}", flush=True) - # plt.savefig(p) - - raise if __name__ == "__main__":