diff --git a/python/sglang/srt/bench_utils.py b/python/sglang/srt/bench_utils.py index e9f7fcbb4..ea400bfa8 100644 --- a/python/sglang/srt/bench_utils.py +++ b/python/sglang/srt/bench_utils.py @@ -1,4 +1,5 @@ import os +import re import sys from contextlib import nullcontext @@ -108,7 +109,8 @@ def bench_kineto( if not with_multiple_kernels: for name in kernel_names: assert ( - sum([name in line for line in prof_lines]) == 1 + sum([int(re.search(name, line) is not None) for line in prof_lines]) + == 1 ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})" # Save chrome traces @@ -122,7 +124,7 @@ def bench_kineto( total_time = 0 total_num = 0 for line in prof_lines: - if name in line: + if re.search(name, line) is not None: time_str = line.split()[-2] num_str = line.split()[-1] for unit, scale in units.items(): diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index f0512365b..9c30dc060 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -43,11 +43,17 @@ _is_cpu = is_cpu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: - from sgl_kernel import ( - sgl_per_tensor_quant_fp8, - sgl_per_token_group_quant_fp8, - sgl_per_token_quant_fp8, - ) + from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8 + + # Temporary + try: + from sgl_kernel import sgl_per_token_group_quant_8bit + + enable_sgl_per_token_group_quant_8bit = True + except ImportError: + from sgl_kernel import sgl_per_token_group_quant_fp8 + + enable_sgl_per_token_group_quant_8bit = False if _is_hip: if _use_aiter: @@ -496,9 +502,24 @@ def sglang_per_token_group_quant_fp8( ) if x.shape[0] > 0: - sgl_per_token_group_quant_fp8( - x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 - ) + # Temporary + if enable_sgl_per_token_group_quant_8bit: + sgl_per_token_group_quant_8bit( + x, + x_q, + x_s, + group_size, + eps, + fp8_min, + fp8_max, + scale_ue8m0, + fuse_silu_and_mul, + masked_m, + ) + else: + sgl_per_token_group_quant_fp8( + x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 + ) return x_q, x_s diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index 7c6c3dbd4..826d16e3c 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import sgl_per_token_group_quant_int8 + # Temporary + try: + from sgl_kernel import sgl_per_token_group_quant_8bit + except ImportError: + from sgl_kernel import ( + sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit, + ) logger = logging.getLogger(__name__) @@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8( dtype=torch.float32, ) - sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) + sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max) return x_q, x_s diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 3f37a3248..7237312ce 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -1,10 +1,12 @@ import itertools +import os import time from functools import partial from pathlib import Path 
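Since `bench_kineto` now matches kernel names with `re.search` instead of a plain substring test, a kernel-name entry may be a regex alternation, so one entry can match whichever of several possible kernels actually appears in the trace (the Triton path in the benchmark below passes "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"). A minimal standalone sketch of that matching logic, using made-up profiler rows rather than real Kineto output:

import re

# Hypothetical profiler rows; a real Kineto table has more columns.
prof_lines = [
    "_silu_and_mul_post_quant_kernel   4.5us   100",
    "unrelated_kernel                  1.0us   100",
]

# A regex alternation: matches whichever of the two kernel variants ran.
name = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"

matches = sum(int(re.search(name, line) is not None) for line in prof_lines)
assert matches == 1  # exactly one row matches, as bench_kineto asserts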
import torch import triton +from sgl_kernel.test_utils import create_per_token_group_quant_test_data from sglang.srt.bench_utils import bench_kineto from sglang.srt.layers.quantization.fp8_kernel import ( @@ -19,78 +21,231 @@ from sglang.srt.utils import is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn +mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated" -num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384] -hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1 -group_size_range = [128] # For DeepSeek V3/R1 -# TODO test int8 -dst_dtype_range = [fp8_type_] -flags_range = [ - dict( - column_major_scales=False, - scale_tma_aligned=False, - scale_ue8m0=False, - ), - dict( - column_major_scales=True, - scale_tma_aligned=False, - scale_ue8m0=False, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=False, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - ), -] - - -configs = list( - itertools.product( - num_tokens_range, - hidden_dim_range, - group_size_range, - dst_dtype_range, - flags_range, +if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")): + # configs = [[ + # 768, + # 16384, + # 128, + # None, + # fp8_type_, + # dict( + # column_major_scales=True, + # scale_tma_aligned=True, + # scale_ue8m0=True, + # fuse_silu_and_mul=False, + # masked_layout_mode=None, + # ), + # ]] + configs = [ + [ + 768 * 8, + 2048, + 128, + 48, + fp8_type_, + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + # masked_layout_mode=None, + masked_layout_mode="balanced", + # masked_layout_mode="extreme", + ), + ] + ] +elif mode_concentrated: + configs = list( + itertools.product( + [768], + [1536, 7168, 16384], + [128], + [None], + [fp8_type_], + [ + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + ], + ) + ) + list( + itertools.product( + [768 * 8], + [2048], + [128], + [48], + [fp8_type_], + [ + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="balanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="imbalanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="extreme", + ), + ], + ) + ) +else: + configs = list( + itertools.product( + [1, 4, 16, 64, 256, 768, 2048, 8192, 16384], + [1536, 7168, 16384], + [128], + [None], + [fp8_type_], + [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + ], + ) + ) + list( + itertools.product( + [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8], + [2048], + [128], + [8, 16, 32, 48], + 
[fp8_type_], + [ + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="balanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="imbalanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="extreme", + ), + ], + ) ) -) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"], + x_names=[ + "num_tokens", + "hidden_dim", + "group_size", + "num_ranks", + "dst_dtype", + "flags", + ], x_vals=configs, line_arg="provider", line_vals=["triton", "sglang"], - line_names=["Triton", "SGL Kernel"], + # Triton has multi kernels and we only report the time for the core one + line_names=["Triton (Inaccurate)", "SGL Kernel"], styles=[("blue", "-"), ("green", "-")], ylabel="us", plot_name="per-token-group-quant-8bit-performance", args={}, ) ) -def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider): - if flags["scale_ue8m0"] and group_size != 128: - return +def benchmark( + num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider +): + print( + f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}" + ) - device = torch.device("cuda") - - x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) + x, masked_m = create_per_token_group_quant_test_data( + num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags + ) fn, kernel_names = { - "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"), + "triton": ( + triton_per_token_group_quant_8bit, + "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel", + ), "sglang": ( sglang_per_token_group_quant_8bit, "per_token_group_quant_8bit_kernel", ), }[provider] - bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags) + bench_fn = lambda: fn( + x=x, + masked_m=masked_m, + group_size=group_size, + dst_dtype=dst_dtype, + **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]}, + ) - time_s = bench_kineto(bench_fn, kernel_names=kernel_names) + time_s = bench_kineto( + bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30 + ) return time_s * 1e6 diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 8ff06f454..54587b1be 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -121,14 +121,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm); m.def( - "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," - " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()"); - m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); - - m.def( - "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size," - " float eps, float int8_min, float int8_max) -> ()"); - m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8); + "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps, 
float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()"); + m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit); m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8); diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index 474164ce6..1944e6d37 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -1,119 +1,396 @@ #include -#include +#include #include #include #include "utils.h" +template __device__ __forceinline__ float GroupReduceMax(float val, const int tid) { unsigned mask = 0xffff; - val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); - val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); - val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); - val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + static_assert( + (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1, + "THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16"); + + if constexpr (THREADS_PER_SUBWARP >= 16) { + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + } + if constexpr (THREADS_PER_SUBWARP >= 8) { + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + } + if constexpr (THREADS_PER_SUBWARP >= 4) { + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + } + if constexpr (THREADS_PER_SUBWARP >= 2) { + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + } return val; } +__device__ __forceinline__ float silu(const float& val) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + float half = 0.5f * val; + float t = __tanhf(half); + return half * (1.0f + t); +#else + return val / (1.0f + __expf(-val)); +#endif +} + +__device__ float2 fmul2_rn(float2 a, float2 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + return __fmul2_rn(a, b); +#else + float2 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#endif +} + +// Copied and modified from DeepEP +__forceinline__ __device__ float fast_pow2(int x) { + // We can ensure `-126 <= x and x <= 127` + uint32_t bits_x = (x + 127) << 23; + return *reinterpret_cast(&bits_x); +} + +// Copied and modified from DeepEP +__forceinline__ __device__ int fast_log2_ceil(float x) { + auto bits_x = *reinterpret_cast(&x); + auto exp_x = (bits_x >> 23) & 0xff; + auto man_bits = bits_x & ((1 << 23) - 1); + return exp_x - 127 + (man_bits != 0); +} + +// Copied and modified from DeepEP +template +__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) { + constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX; + if constexpr (ROUND_SCALE) { + auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV); + scale = fast_pow2(-exp_scale_inv); + scale_inv = fast_pow2(exp_scale_inv); + } else { + scale_inv = amax * MAX_8BIT_INV; + scale = dtype_info::MAX / amax; + } +} + +// Copied and modified from DeepEP +template > +__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) { + if constexpr (SCALE_UE8M0) { + return static_cast((*reinterpret_cast(&value)) >> 23); + } else { + return value; + } +} + +__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) { + asm volatile( + "st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); +} + +__device__ __forceinline__ 
int4 ld_global_nc(const int4* ptr) { + int4 ret; + asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) + : "l"(ptr)); + return ret; +} + +template +struct DtypeInfo; + +template <> +struct DtypeInfo { + static constexpr float MIN = -128; + static constexpr float MAX = 127; +}; + +template <> +struct DtypeInfo { + static constexpr float MIN = -448; + static constexpr float MAX = 448; +}; + +template +__device__ __forceinline__ int compute_input_group_start_offset( + int expert_idx, + int token_idx, + int hidden_dim_group_idx, + int hidden_size, + int num_tokens_per_expert, + int group_size) { + return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + + token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size; +} + +constexpr float LOCAL_ABSMAX_ABS = 1e-10; +constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32; + +struct NaiveScheduler { + static void compute_exec_config( + int threads_per_subwarp, + int num_local_experts, + int hidden_dim_num_groups, + int num_groups, + int& subwarps_per_block, + dim3& grid, + dim3& block) { + subwarps_per_block = ([=]() -> int { + if (num_groups % 16 == 0) { + return 16; + } else if (num_groups % 8 == 0) { + return 8; + } else if (num_groups % 4 == 0) { + return 4; + } else if (num_groups % 2 == 0) { + return 2; + } + return 1; + })(); + grid = dim3(num_groups / subwarps_per_block); + block = dim3(subwarps_per_block * threads_per_subwarp); + } + + template + __device__ __forceinline__ static void execute( + const int subwarps_per_block, + const int hidden_dim_num_groups, + const int32_t* masked_m, + const int num_tokens_per_expert, + FUNC fn) { + constexpr int expert_idx = 0; + + const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP; + const int lane_id = threadIdx.x % THREADS_PER_SUBWARP; + + const int64_t block_group_id = blockIdx.x * subwarps_per_block; + const int64_t group_id = block_group_id + subwarp_id; + + int64_t input_group_start_offset; + if constexpr (!FUSE_SILU_AND_MUL) { + input_group_start_offset = group_id * GROUP_SIZE; + } + + const int token_idx = group_id / hidden_dim_num_groups; + // At the hidden_size dimension, we are handling idx-th group + const int hidden_dim_group_idx = group_id % hidden_dim_num_groups; + + if constexpr (FUSE_SILU_AND_MUL) { + const int hidden_size = hidden_dim_num_groups * GROUP_SIZE; + input_group_start_offset = compute_input_group_start_offset( + expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE); + } + + fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset); + } +}; + +struct MaskedLayoutScheduler { + // TODO can be dynamically determined (which may be good when num rank is small) + static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024; + static constexpr int SUBWARPS_PER_BLOCK = 16; + + static void compute_exec_config( + int threads_per_subwarp, + int num_local_experts, + int hidden_dim_num_groups, + int num_groups, + int& subwarps_per_block, + dim3& grid, + dim3& block) { + subwarps_per_block = SUBWARPS_PER_BLOCK; + TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0); + grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts); + block = dim3(subwarps_per_block * threads_per_subwarp); + } + + template + __device__ __forceinline__ static void execute( + const int subwarps_per_block, + const int hidden_dim_num_groups, + const int32_t* masked_m, + 
const int num_tokens_per_expert, + FUNC fn) { + const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP; + const int lane_id = threadIdx.x % THREADS_PER_SUBWARP; + + const int expert_idx = blockIdx.z; + const int token_idx_start = blockIdx.y; + + const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id; + + const int curr_expert_token_num = masked_m[expert_idx]; + + for (int token_idx = token_idx_start; token_idx < curr_expert_token_num; + token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) { + const int hidden_size = hidden_dim_num_groups * GROUP_SIZE; + const int64_t input_group_start_offset = compute_input_group_start_offset( + expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE); + fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset); + } + } +}; + template < + typename SCHEDULER, + int GROUP_SIZE, + int THREADS_PER_SUBWARP, typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false, bool SCALE_UE8M0 = false, + bool FUSE_SILU_AND_MUL = false, typename scale_packed_t = std::conditional_t> __global__ void per_token_group_quant_8bit_kernel( const T* __restrict__ input, - void* __restrict__ output_q, + DST_DTYPE* __restrict__ output_q, scale_packed_t* __restrict__ output_s, - const int group_size, - const int num_groups, - const int groups_per_block, - const float eps, - const float min_8bit, - const float max_8bit, - const int num_groups_per_row = 0, - const int scale_stride = 0) { - const int threads_per_group = 16; - const int64_t local_group_id = threadIdx.x / threads_per_group; - const int lane_id = threadIdx.x % threads_per_group; - - const int64_t block_group_id = blockIdx.x * groups_per_block; - const int64_t global_group_id = block_group_id + local_group_id; - const int64_t block_group_offset = global_group_id * group_size; - - float local_absmax = eps; - + const int32_t* __restrict__ masked_m, + const int subwarps_per_block, + const int hidden_dim_num_groups, + // TODO can this be removed? 
+ const int scale_expert_stride, + const int scale_hidden_stride, + const int num_tokens_per_expert) { + using dst_dtype_info = DtypeInfo; using scale_element_t = std::conditional_t; static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); - const T* group_input = input + block_group_offset; - DST_DTYPE* group_output = static_cast(output_q) + block_group_offset; - scale_element_t* scale_output; + SCHEDULER::execute( + subwarps_per_block, + hidden_dim_num_groups, + masked_m, + num_tokens_per_expert, + [&](const int expert_idx, + const int token_idx, + const int hidden_dim_group_idx, + const int lane_id, + const int input_group_start_offset) { + constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T); + constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4); - if constexpr (IS_COLUMN_MAJOR) { - const int num_elems_per_pack = static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); - const int row_idx = global_group_id / num_groups_per_row; - const int col_idx_unpacked = global_group_id % num_groups_per_row; - const int col_idx = col_idx_unpacked / num_elems_per_pack; - const int pack_idx = col_idx_unpacked % num_elems_per_pack; - scale_output = reinterpret_cast(output_s) + - (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx); - } else { - static_assert(!SCALE_UE8M0); - scale_output = output_s + global_group_id; - } + const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups + + token_idx * hidden_dim_num_groups + hidden_dim_group_idx; - constexpr uint32_t vec_size = 16 / sizeof(T); - using vec_t = flashinfer::vec_t; + int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE]; + T* input_primary_vec = reinterpret_cast(input_primary_int4); + static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4)); - const int32_t num_vec_elems = group_size / vec_size; - - for (int32_t i = lane_id; i < num_vec_elems; i += 16) { - vec_t input_vec; - input_vec.cast_load(group_input + i * vec_size); + int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE]; + T* input_secondary_vec = reinterpret_cast(input_secondary_int4); + static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4)); #pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - float val = static_cast(input_vec[j]); - float abs_val = fabsf(val); - local_absmax = fmaxf(local_absmax, abs_val); - } - } + for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) { + input_primary_int4[j] = ld_global_nc( + reinterpret_cast(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j); + } + if constexpr (FUSE_SILU_AND_MUL) { + const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE; +#pragma unroll + for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) { + input_secondary_int4[j] = ld_global_nc( + reinterpret_cast( + input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) + + j); + } + } - local_absmax = GroupReduceMax(local_absmax, lane_id); + constexpr int num_elems_per_pack = static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); + scale_element_t* scale_output; + if constexpr (IS_COLUMN_MAJOR) { + constexpr int scale_token_stride = 1; - float y_s = local_absmax / max_8bit; - if constexpr (SCALE_UE8M0) { - y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f)))); - } + const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack; + const int pack_idx = hidden_dim_group_idx % 
num_elems_per_pack; + scale_output = reinterpret_cast(output_s) + + (expert_idx * scale_expert_stride * num_elems_per_pack + + hidden_idx_packed * scale_hidden_stride * num_elems_per_pack + + token_idx * scale_token_stride * num_elems_per_pack + pack_idx); + } else { + static_assert(!SCALE_UE8M0); + scale_output = output_s + offset_num_groups; + } - // TODO can optimize - scale_element_t y_s_quant; - if constexpr (SCALE_UE8M0) { - y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127); - } else { - y_s_quant = y_s; - } + // can speed up if too slow + if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) { + const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack; + if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and + (lane_id < num_elems_per_pack - remainder_num_groups)) { + const int shift = 1 + lane_id; + *(scale_output + shift) = 0; + } + } - if (lane_id == 0) { - *scale_output = y_s_quant; - } - - for (int32_t i = lane_id; i < num_vec_elems; i += 16) { - vec_t input_vec; - input_vec.cast_load(group_input + i * vec_size); + float local_absmax = LOCAL_ABSMAX_ABS; #pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - float val = static_cast(input_vec[j]); - float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit); - group_output[i * vec_size + j] = DST_DTYPE(q_val); - } - } + for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) { + float val; + if constexpr (FUSE_SILU_AND_MUL) { + // TODO maybe vectorize + T val_lowprec = static_cast(silu(static_cast(input_primary_vec[j]))) * input_secondary_vec[j]; + val = static_cast(val_lowprec); + input_primary_vec[j] = val_lowprec; + } else { + val = static_cast(input_primary_vec[j]); + } + + float abs_val = fabsf(val); + local_absmax = fmaxf(local_absmax, abs_val); + } + + local_absmax = GroupReduceMax(local_absmax, lane_id); + + float y_scale, y_scale_inv; + calculate_fp8_scales(local_absmax, y_scale, y_scale_inv); + float2 y_scale_repeated = {y_scale, y_scale}; + + if (lane_id == 0) { + *scale_output = extract_required_scale_format(y_scale_inv); + } + + int4 output_buf; + static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE)); + + if constexpr (std::is_same_v) { + const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf); + static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t)); + static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0); + +#pragma unroll + for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) { + float2 inputx2 = {static_cast(input_primary_vec[j]), static_cast(input_primary_vec[j + 1])}; + float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated); + output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3); + } + } else { + const auto output_buf_ptr = reinterpret_cast(&output_buf); + +#pragma unroll + for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) { + float val = static_cast(input_primary_vec[j]); + float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX); + output_buf_ptr[j] = DST_DTYPE(q_val); + } + } + + st_global( + reinterpret_cast(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE), + output_buf); + }); } void sgl_per_token_group_quant_8bit( + // vanilla: (num_tokens, hidden_size) + // fuse_silu_and_mul: (num_tokens, hidden_size * 2) + // fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2) torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, @@ -121,120 
+398,113 @@ void sgl_per_token_group_quant_8bit( double eps, double min_8bit, double max_8bit, - bool scale_ue8m0 = false) { + bool scale_ue8m0, + bool fuse_silu_and_mul, + const std::optional& masked_m) { CHECK_INPUT(input); CHECK_INPUT(output_q); + TORCH_CHECK(input.numel() > 0); - const int num_groups = input.numel() / group_size; + TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13); CHECK_EQ(input.numel() % group_size, 0); - CHECK_EQ(output_s.dim(), 2); + const int num_groups = static_cast(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1); + + const bool masked_layout = masked_m.has_value(); + TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2)); + + const int num_local_experts = masked_layout ? input.size(0) : 1; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - constexpr int THREADS_PER_GROUP = 16; - - int groups_per_block = 1; - - if (num_groups % 16 == 0) { - groups_per_block = 16; - } else if (num_groups % 8 == 0) { - groups_per_block = 8; - } else if (num_groups % 4 == 0) { - groups_per_block = 4; - } else if (num_groups % 2 == 0) { - groups_per_block = 2; - } - auto dst_type = output_q.scalar_type(); - const int num_blocks = num_groups / groups_per_block; - const int num_threads = groups_per_block * THREADS_PER_GROUP; - const bool is_column_major = output_s.stride(0) < output_s.stride(1); - const int hidden_dim = input.size(input.dim() - 1); - const int num_groups_per_row = hidden_dim / group_size; - const int scale_stride = output_s.stride(1); + const bool is_column_major = output_s.stride(-2) < output_s.stride(-1); + const int hidden_dim_num_groups = static_cast(output_q.size(-1)) / group_size; + const int num_tokens_per_expert = static_cast(output_q.size(-2)); + const int scale_expert_stride = masked_layout ? static_cast(output_s.stride(0)) : 0; + const int scale_hidden_stride = static_cast(output_s.stride(-1)); -#define LAUNCH_KERNEL(T, DST_DTYPE) \ - do { \ - dim3 grid(num_blocks); \ - dim3 block(num_threads); \ - if (is_column_major) { \ - if (scale_ue8m0) { \ - per_token_group_quant_8bit_kernel<<>>( \ - static_cast(input.data_ptr()), \ - output_q.data_ptr(), \ - static_cast(output_s.data_ptr()), \ - group_size, \ - num_groups, \ - groups_per_block, \ - (float)eps, \ - (float)min_8bit, \ - (float)max_8bit, \ - num_groups_per_row, \ - scale_stride); \ - } else { \ - per_token_group_quant_8bit_kernel<<>>( \ - static_cast(input.data_ptr()), \ - output_q.data_ptr(), \ - static_cast(output_s.data_ptr()), \ - group_size, \ - num_groups, \ - groups_per_block, \ - (float)eps, \ - (float)min_8bit, \ - (float)max_8bit, \ - num_groups_per_row, \ - scale_stride); \ - } \ - } else { \ - assert(!scale_ue8m0); \ - per_token_group_quant_8bit_kernel<<>>( \ - static_cast(input.data_ptr()), \ - output_q.data_ptr(), \ - static_cast(output_s.data_ptr()), \ - group_size, \ - num_groups, \ - groups_per_block, \ - (float)eps, \ - (float)min_8bit, \ - (float)max_8bit); \ - } \ +#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \ + do { \ + int subwarps_per_block; \ + dim3 grid, block; \ + SCHEDULER::compute_exec_config( \ + THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \ + \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input.data_ptr()), \ + static_cast(output_q.data_ptr()), \ + static_cast(output_s.data_ptr()), \ + static_cast(masked_m.has_value() ? 
masked_m->data_ptr() : 0), \ + subwarps_per_block, \ + hidden_dim_num_groups, \ + scale_expert_stride, \ + scale_hidden_stride, \ + num_tokens_per_expert); \ } while (0) - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { +#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE) \ + do { \ + constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16; \ + TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T)); \ + \ + using dst_dtype_info = DtypeInfo; \ + CHECK_EQ(dst_dtype_info::MIN, min_8bit); \ + CHECK_EQ(dst_dtype_info::MAX, max_8bit); \ + \ + if (is_column_major) { \ + if (scale_ue8m0) { \ + if (fuse_silu_and_mul) { \ + if (masked_layout) { \ + LAUNCH_KERNEL_INNER( \ + MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \ + } else { \ + LAUNCH_KERNEL_INNER( \ + NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \ + } \ + } else { \ + LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \ + } \ + } else { \ + LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true); \ + } \ + } else { \ + LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false); \ + } \ + } while (0) + +#define LAUNCH_KERNEL_OUTER(...) \ + switch (group_size) { \ + case 16: \ + LAUNCH_KERNEL(16, __VA_ARGS__); \ + break; \ + case 32: \ + LAUNCH_KERNEL(32, __VA_ARGS__); \ + break; \ + case 64: \ + LAUNCH_KERNEL(64, __VA_ARGS__); \ + break; \ + case 128: \ + LAUNCH_KERNEL(128, __VA_ARGS__); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported group_size"); \ + } \ + while (0) + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] { if (dst_type == at::ScalarType::Char) { - LAUNCH_KERNEL(scalar_t, int8_t); + LAUNCH_KERNEL_OUTER(scalar_t, int8_t); return true; } else if (dst_type == at::ScalarType::Float8_e4m3fn) { - LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3); + LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn); return true; } return false; }); #undef LAUNCH_KERNEL -} - -void sgl_per_token_group_quant_int8( - torch::Tensor input, - torch::Tensor output_q, - torch::Tensor output_s, - int64_t group_size, - double eps, - double int8_min, - double int8_max) { - sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max); -} - -void sgl_per_token_group_quant_fp8( - torch::Tensor input, - torch::Tensor output_q, - torch::Tensor output_s, - int64_t group_size, - double eps, - double fp8_min, - double fp8_max, - bool scale_ue8m0) { - sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0); +#undef LAUNCH_KERNEL_INNER } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 28422ad18..b6c40c801 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -207,23 +207,17 @@ torch::Tensor fp8_blockwise_scaled_mm( const torch::Dtype& out_dtype); void scaled_fp4_quant( torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); -void sgl_per_token_group_quant_fp8( +void sgl_per_token_group_quant_8bit( at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps, - double fp8_min, - double fp8_max, - bool scale_ue8m0); -void sgl_per_token_group_quant_int8( - at::Tensor input, - at::Tensor output_q, - at::Tensor 
output_s, - int64_t group_size, - double eps, - double int8_min, - double int8_max); + double min_8bit, + double max_8bit, + bool scale_ue8m0, + bool fuse_silu_and_mul, + const std::optional& masked_m); void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static); void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s); void bmm_fp8( diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 05a62efaa..cf771d553 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -55,8 +55,7 @@ from sgl_kernel.gemm import ( scaled_fp4_grouped_quant, scaled_fp4_quant, sgl_per_tensor_quant_fp8, - sgl_per_token_group_quant_fp8, - sgl_per_token_group_quant_int8, + sgl_per_token_group_quant_8bit, sgl_per_token_quant_fp8, shuffle_rows, silu_and_mul_scaled_fp4_grouped_quant, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 36672877d..1a4c5d2d5 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -98,7 +98,7 @@ def dsv3_fused_a_gemm( return output -def sgl_per_token_group_quant_fp8( +def sgl_per_token_group_quant_8bit( input: torch.Tensor, output_q: torch.Tensor, output_s: torch.Tensor, @@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8( eps: float, fp8_min: float, fp8_max: float, - scale_ue8m0: bool, + scale_ue8m0: bool = False, + fuse_silu_and_mul: bool = False, + masked_m: Optional[torch.Tensor] = None, ) -> None: - torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default( - input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 - ) - - -def sgl_per_token_group_quant_int8( - input: torch.Tensor, - output_q: torch.Tensor, - output_s: torch.Tensor, - group_size: int, - eps: float, - int8_min: float, - int8_max: float, -) -> None: - torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default( - input, output_q, output_s, group_size, eps, int8_min, int8_max + torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default( + input, + output_q, + output_s, + group_size, + eps, + fp8_min, + fp8_max, + scale_ue8m0, + fuse_silu_and_mul, + masked_m, ) diff --git a/sgl-kernel/python/sgl_kernel/test_utils.py b/sgl-kernel/python/sgl_kernel/test_utils.py new file mode 100644 index 000000000..ede113fd0 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/test_utils.py @@ -0,0 +1,125 @@ +import torch + + +def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags): + device = torch.device("cuda") + dtype = torch.bfloat16 + + seed = num_tokens * 10000 + hidden_dim + gen_cpu = torch.Generator(device="cpu") + gen_cpu.manual_seed(seed) + gen_cuda = torch.Generator(device="cuda") + gen_cuda.manual_seed(seed) + + if flags["fuse_silu_and_mul"]: + effective_hidden_dim = hidden_dim * 2 + else: + effective_hidden_dim = hidden_dim + del hidden_dim + + if (masked_layout_mode := flags["masked_layout_mode"]) is not None: + num_max_dispatch_tokens_per_rank = 768 + num_global_experts = 288 + num_local_experts, remainder = divmod(num_global_experts, num_ranks) + assert remainder == 0 + + # mimic DeepEP low_latency_dispatch output + x = torch.randn( + num_local_experts, + num_max_dispatch_tokens_per_rank * num_ranks, + effective_hidden_dim, + device=device, + dtype=dtype, + generator=gen_cuda, + ) + + if masked_layout_mode == "balanced": + masked_m = _compute_balanced_split(num_tokens, num_local_experts) + elif masked_layout_mode == 
"imbalanced": + masked_m = _compute_imbalanced_split( + num_tokens, num_local_experts, gen_cpu=gen_cpu + ) + elif masked_layout_mode == "extreme": + masked_m = torch.tensor( + [num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int + ) + else: + raise NotImplementedError + print(f"{masked_layout_mode=} {masked_m=} {x.shape=}") + + masked_m = masked_m.to(device) + + return x, masked_m + else: + x = torch.randn( + num_tokens, + effective_hidden_dim, + device=device, + dtype=dtype, + generator=gen_cuda, + ) + x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10 + return x, None + + +def _compute_balanced_split(total: int, arr_len: int): + base = total // arr_len + remainder = total % arr_len + ans = [base + 1 if i < remainder else base for i in range(arr_len)] + assert sum(ans) == total + return torch.tensor(ans, dtype=torch.int) + + +def _compute_imbalanced_split( + total: int, arr_len: int, gen_cpu, dtype=torch.int +) -> list[int]: + # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is + noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3 + + noise = noise_raw / noise_raw.sum() + ans = (noise * total).round().to(dtype) + + diff = total - ans.sum().item() + while diff != 0: + idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item() + if diff > 0: + ans[idx] += 1 + diff -= 1 + elif diff < 0 and ans[idx] > 0: + ans[idx] -= 1 + diff += 1 + + assert sum(ans) == total + return ans + + +def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor): + assert (a.shape == b.shape) and ( + a.dtype == b.dtype + ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}" + numel = a.numel() + + if a.dtype == torch.float8_e4m3fn: + a_u8 = a.view(torch.uint8) + b_u8 = b.view(torch.uint8) + diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs() + + count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item() + count_tiny_diff = (diff_u8 == 1).sum().item() + count_large_diff = (diff_u8 >= 2).sum().item() + elif a.dtype == torch.int8: + diff = (a.to(torch.int16) - a.to(torch.int16)).abs() + count_diff_sign = ((a >= 0) & (b < 0)).sum().item() + count_tiny_diff = (diff == 1).sum().item() + count_large_diff = (diff >= 2).sum().item() + else: + raise NotImplementedError + + assert ( + (count_diff_sign == 0) + and (count_large_diff == 0) + and ( + (count_tiny_diff / numel < 0.005) + or ((count_tiny_diff / numel < 0.04) and (numel <= 4096)) + ) + ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}" diff --git a/sgl-kernel/tests/test_per_token_group_quant_8bit.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py index 778d14d31..f47c78414 100644 --- a/sgl-kernel/tests/test_per_token_group_quant_8bit.py +++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py @@ -1,96 +1,199 @@ import itertools +import os +import time +from pathlib import Path import pytest import torch +from sgl_kernel.test_utils import ( + assert_all_close_or_tiny_diff, + create_per_token_group_quant_test_data, +) -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_8bit as triton_per_token_group_quant_8bit, ) from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit -from sglang.srt.layers.quantization.utils import assert_fp8_all_close -from sglang.srt.utils import is_hip +from sglang.srt.utils import get_bool_env_var, is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn +configs = list( + 
itertools.product( + [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192], # num_tokens + [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384], # hidden_dim + [16, 32, 64, 128], # group_size + [None], # num_ranks + [fp8_type_, torch.int8], # dtype + [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + ], + ) +) + list( + itertools.product( + [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8], + # TODO support more + [2048], + [128], + [8, 16, 32, 48], + [fp8_type_], + [ + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="balanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="imbalanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="extreme", + ), + ], + ) +) + @pytest.mark.parametrize( - "num_tokens, hidden_dim, group_size, dst_dtype, flags", - list( - itertools.product( - [127, 128, 512, 1024, 4096, 8192], # num_tokens - [256, 512, 1024, 2048, 4096], # hidden_dim - [8, 16, 32, 64, 128], # group_size - # TODO test int8 - [fp8_type_], # dtype - [ - dict( - column_major_scales=False, - scale_tma_aligned=False, - scale_ue8m0=False, - ), - dict( - column_major_scales=True, - scale_tma_aligned=False, - scale_ue8m0=False, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=False, - ), - dict( - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=True, - ), - ], - ) - ), + "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs ) def test_per_token_group_quant_with_column_major( num_tokens, hidden_dim, group_size, + num_ranks, dst_dtype, flags, ): - if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)): - pytest.skip() - return - if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL: - pytest.skip("scale_ue8m0 only supported on Blackwell") + print( + f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}" + ) + + arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device()) + if flags["scale_ue8m0"] and (arch_major <= 9): + pytest.skip("Only Blackwell need ue8m0 fusion") return - x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + if (flags["scale_ue8m0"] and (group_size != 128)) or ( + (dst_dtype == torch.int8) and flags["column_major_scales"] + ): + pytest.skip() + return + + x, masked_m = create_per_token_group_quant_test_data( + num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags + ) + + # print("hack data!!!") + # x = torch.full_like(x, fill_value=100) execute_kwargs = dict( x=x, + masked_m=masked_m, group_size=group_size, eps=1e-10, dst_dtype=dst_dtype, - **flags, + **{k: v for k, v in 
flags.items() if k not in ["masked_layout_mode"]}, ) - x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs) - x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs) + def _postprocess(x_q, x_s): + if masked_m is not None: + print(f"Mask tokens after {masked_m} to be zero") + for i in range(len(masked_m)): + x_q[i, masked_m[i] :, :] = 0 + x_s[i, masked_m[i] :, :] = 0 + return x_q, x_s - # torch.set_printoptions(profile="full") - # print(f"{x_q_triton=}") - # print(f"{x_s_triton=}") - # print(f"{x_q_sglang=}") - # print(f"{x_s_sglang=}") - # torch.set_printoptions(profile="default") - - assert_fp8_all_close(x_q_triton, x_q_sglang) - torch.testing.assert_close( - x_s_triton.contiguous(), - x_s_sglang.contiguous(), - rtol=1e-3, - atol=1e-5, - msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}", + x_q_triton, x_s_triton = _postprocess( + *triton_per_token_group_quant_8bit(**execute_kwargs) ) + x_q_sglang, x_s_sglang = _postprocess( + *sglang_per_token_group_quant_8bit(**execute_kwargs) + ) + + try: + assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang) + torch.testing.assert_close( + x_s_triton.contiguous(), + x_s_sglang.contiguous(), + rtol=1e-3, + atol=1e-5, + msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}", + ) + except AssertionError: + # torch.set_printoptions(profile="full") + print( + f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}" + ) + print(f"{x=}") + print(f"{masked_m=}") + print(f"{x_q_triton=}") + print(f"{x_s_triton=}") + print(f"{x_q_sglang=}") + print(f"{x_s_sglang=}") + # torch.set_printoptions(profile="default") + + # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "": + # import matplotlib.pyplot as plt + # + # base_stem = time.time() + # for name, value in [ + # ("x_q", x_q_triton != x_q_sglang), + # ("x_s", x_s_triton != x_s_sglang), + # ]: + # value = value.reshape((-1, value.shape[-1])) + # plt.figure(figsize=(20, 20)) + # plt.imshow((value * 1.0).cpu().numpy()) + # p = Path(d) / f"{base_stem}_{name}.png" + # print(f"Write diff to {p}", flush=True) + # plt.savefig(p) + + raise if __name__ == "__main__":
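The `scale_ue8m0` path in the CUDA kernel rounds each group's dequantization scale up to a power of two and stores only its biased exponent byte (`extract_required_scale_format` returns `float_bits >> 23` when `SCALE_UE8M0` is set). A small standalone Python check of that encoding, an assumed illustration rather than kernel code:

import math
import struct

def ue8m0_encode(scale_inv: float) -> int:
    # Round the dequant scale up to the next power of two, then keep only
    # the biased exponent byte: a power-of-two float has a zero mantissa,
    # so (float_bits >> 23) is exactly the UE8M0 value the kernel stores.
    rounded = 2.0 ** math.ceil(math.log2(scale_inv))
    bits = struct.unpack("<I", struct.pack("<f", rounded))[0]
    assert bits & ((1 << 23) - 1) == 0  # power of two => zero mantissa
    return bits >> 23

assert ue8m0_encode(0.75) == 127  # rounds up to 1.0 = 2**0 -> biased exponent 127
assert ue8m0_encode(3.0) == 129   # rounds up to 4.0 = 2**2 -> biased exponent 129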