diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py similarity index 62% rename from sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py rename to sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 0d90b51b3..b03369c1d 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -4,7 +4,7 @@ from typing import Tuple import torch import triton import triton.language as tl -from sgl_kernel import sgl_per_token_group_quant_fp8 +from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8 from sglang.srt.utils import is_hip @@ -13,7 +13,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn @triton.jit -def _per_token_group_quant_fp8( +def _per_token_group_quant_8bit( # Pointers to inputs and output y_ptr, y_q_ptr, @@ -24,16 +24,15 @@ def _per_token_group_quant_fp8( N, # Avoid to divide zero eps, - # Information for float8 - fp8_min, - fp8_max, + # Information for 8bit data type (int8 or fp8_type_) + max_8bit, + min_8bit, # Meta-parameters BLOCK: tl.constexpr, ): """A Triton-accelerated function to perform per-token-group quantization on a tensor. - - This function converts the tensor values into float8 values. + This function converts the tensor values into 8bit values. """ # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) @@ -47,30 +46,27 @@ def _per_token_group_quant_fp8( y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max - y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + y_s = _absmax / max_8bit + y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) -def triton_per_token_group_quant_fp8( +def triton_per_token_group_quant_8bit( x: torch.Tensor, group_size: int, + dst_dtype: torch.dtype, eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. - It converts the tensor values into signed float8 values and returns the quantized tensor along with the scaling factor used for quantization. - Args: x: The input tenosr with ndim >= 2. group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. - Returns: Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ @@ -79,12 +75,16 @@ def triton_per_token_group_quant_fp8( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - finfo = torch.finfo(dtype) - fp8_max = finfo.max + if dst_dtype == torch.int8: + iinfo = torch.iinfo(dst_dtype) + max_8bit = iinfo.max + min_8bit = iinfo.min + else: + finfo = torch.finfo(dst_dtype) + max_8bit = finfo.max + min_8bit = finfo.min - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) M = x.numel() // group_size N = group_size x_s = torch.empty( @@ -97,15 +97,15 @@ def triton_per_token_group_quant_fp8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_fp8[(M,)]( + _per_token_group_quant_8bit[(M,)]( x, x_q, x_s, group_size, N, eps, - fp8_min=fp8_min, - fp8_max=fp8_max, + max_8bit, + min_8bit, BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, @@ -114,50 +114,55 @@ def triton_per_token_group_quant_fp8( return x_q, x_s -def sglang_per_token_group_quant_fp8( +def sglang_per_token_group_quant_8bit( x: torch.Tensor, group_size: int, + dst_dtype: torch.dtype, eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, ): assert ( x.shape[-1] % group_size == 0 ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size + x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) x_s = torch.empty( x.shape[:-1] + (x.shape[-1] // group_size,), device=x.device, dtype=torch.float32, ) - sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) + if dst_dtype == torch.int8: + iinfo = torch.iinfo(dst_dtype) + int8_max = iinfo.max + int8_min = iinfo.min + sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) + else: + f8_info = torch.finfo(dst_dtype) + fp8_max = f8_info.max + fp8_min = f8_info.min + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) return x_q, x_s -def calculate_diff(batch_size, seq_len, group_size): - dtype = torch.float16 +def calculate_diff(batch_size, seq_len, group_size, dst_dtype): device = torch.device("cuda") hidden_dim = group_size * 2 - x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) - x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size) - x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size) + x_q_triton, x_s_triton = triton_per_token_group_quant_8bit( + x.clone(), group_size, dst_dtype + ) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit( + x.clone(), group_size, dst_dtype + ) if torch.allclose( x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5): - print("✅ All implementations match") + print(f"✅ {dst_dtype} implementations match") else: print("❌ Implementations differ") @@ -165,36 +170,40 @@ def calculate_diff(batch_size, seq_len, group_size): batch_size_range = [1, 2, 4, 8, 16, 32, 64] seq_len_range = [64, 128, 256, 512, 1024, 2048] group_size_range = [128] # For DeepSeek V3/R1 +dst_dtype_range = [torch.int8, fp8_type_] -configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range)) +configs = list( + itertools.product( + batch_size_range, seq_len_range, group_size_range, dst_dtype_range + ) +) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["batch_size", "seq_len", "group_size"], + x_names=["batch_size", "seq_len", "group_size", "dst_dtype"], x_vals=configs, line_arg="provider", line_vals=["triton", "sglang"], line_names=["Triton", "SGL Kernel"], styles=[("blue", "-"), ("green", "-")], ylabel="us", - plot_name="per-token-group-quant-fp8-performance", + plot_name="per-token-group-quant-8bit-performance", args={}, ) ) -def benchmark(batch_size, seq_len, group_size, provider): - dtype = torch.bfloat16 +def benchmark(batch_size, seq_len, group_size, dst_dtype, provider): device = torch.device("cuda") hidden_dim = 7168 - x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] if provider == "triton": - fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size) + fn = lambda: triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype) elif provider == "sglang": - fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size) + fn = lambda: sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype) ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) @@ -203,6 +212,7 @@ def benchmark(batch_size, seq_len, group_size, provider): if __name__ == "__main__": - calculate_diff(batch_size=4, seq_len=128, group_size=64) + calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8) + calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_) benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu similarity index 59% rename from sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu rename to sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index cbf39f041..57a5ab8ad 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -6,8 +6,6 @@ #include "utils.h" -using FP8_TYPE = c10::Float8_e4m3fn; - __device__ __forceinline__ float GroupReduceMax(float val, const int tid) { unsigned mask = 0xffff; @@ -18,27 +16,28 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) { return val; } -template -__global__ void per_token_group_quant_fp8_kernel( +template +__global__ void per_token_group_quant_8bit_kernel( const T* __restrict__ input, void* __restrict__ output_q, float* __restrict__ output_s, const int group_size, const int num_groups, + const int groups_per_block, const float eps, - const float fp8_min, - const float fp8_max) { + const float min_8bit, + const float max_8bit) { const int threads_per_group = 16; const int local_group_id = threadIdx.x / threads_per_group; const int lane_id = threadIdx.x % threads_per_group; - const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK; + const int block_group_id = blockIdx.x * groups_per_block; const int block_group_offset = (block_group_id + local_group_id) * group_size; float local_absmax = eps; const T* group_input = input + block_group_offset; - FP8_TYPE* group_output = static_cast(output_q) + block_group_offset; + DST_DTYPE* group_output = static_cast(output_q) + block_group_offset; float* scale_output = output_s + (block_group_id + local_group_id); constexpr uint32_t vec_size = 16 / sizeof(T); @@ -60,7 +59,7 @@ __global__ void per_token_group_quant_fp8_kernel( local_absmax = GroupReduceMax(local_absmax, lane_id); - const float y_s = local_absmax / fp8_max; + const float y_s = local_absmax / max_8bit; if (lane_id == 0) { *scale_output = y_s; @@ -73,20 +72,20 @@ __global__ void per_token_group_quant_fp8_kernel( #pragma unroll for (uint32_t j = 0; j < vec_size; ++j) { float val = static_cast(input_vec[j]); - float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max); - group_output[i * vec_size + j] = FP8_TYPE(q_val); + float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit); + group_output[i * vec_size + j] = DST_DTYPE(q_val); } } } -void sgl_per_token_group_quant_fp8( +void sgl_per_token_group_quant_8bit( torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, int64_t group_size, double eps, - double fp8_min, - double fp8_max) { + double min_8bit, + double max_8bit) { CHECK_INPUT(input); CHECK_INPUT(output_q); CHECK_INPUT(output_s); @@ -111,36 +110,58 @@ void sgl_per_token_group_quant_fp8( groups_per_block = 2; } -#define LAUNCH_KERNEL(T, GPB) \ - do { \ - constexpr int GROUPS_PER_BLOCK = GPB; \ - dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK); \ - dim3 block(GROUPS_PER_BLOCK* THREADS_PER_GROUP); \ - per_token_group_quant_fp8_kernel<<>>( \ - static_cast(input.data_ptr()), \ - output_q.data_ptr(), \ - static_cast(output_s.data_ptr()), \ - group_size, \ - num_groups, \ - (float)eps, \ - (float)fp8_min, \ - (float)fp8_max); \ + auto dst_type = output_q.scalar_type(); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + +#define LAUNCH_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + per_token_group_quant_8bit_kernel<<>>( \ + static_cast(input.data_ptr()), \ + output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), \ + group_size, \ + num_groups, \ + groups_per_block, \ + (float)eps, \ + (float)min_8bit, \ + (float)max_8bit); \ } while (0) DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { - if (groups_per_block == 16) { - LAUNCH_KERNEL(scalar_t, 16); - } else if (groups_per_block == 8) { - LAUNCH_KERNEL(scalar_t, 8); - } else if (groups_per_block == 4) { - LAUNCH_KERNEL(scalar_t, 4); - } else if (groups_per_block == 2) { - LAUNCH_KERNEL(scalar_t, 2); - } else { - LAUNCH_KERNEL(scalar_t, 1); + if (dst_type == at::ScalarType::Char) { + LAUNCH_KERNEL(scalar_t, int8_t); + return true; + } else if (dst_type == at::ScalarType::Float8_e4m3fn) { + LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn); + return true; } - return true; + return false; }); #undef LAUNCH_KERNEL } + +void sgl_per_token_group_quant_int8( + torch::Tensor input, + torch::Tensor output_q, + torch::Tensor output_s, + int64_t group_size, + double eps, + double int8_min, + double int8_max) { + sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max); +} + +void sgl_per_token_group_quant_fp8( + torch::Tensor input, + torch::Tensor output_q, + torch::Tensor output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max) { + sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max); +} diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 29bf9427b..fe5e09734 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -98,6 +98,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { " float eps, float fp8_min, float fp8_max) -> ()"); m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); + m.def( + "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps, float int8_min, float int8_max) -> ()"); + m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8); + m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8); diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index cfb39e9d4..98bd7bdac 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -141,6 +141,14 @@ void sgl_per_token_group_quant_fp8( double eps, double fp8_min, double fp8_max); +void sgl_per_token_group_quant_int8( + at::Tensor input, + at::Tensor output_q, + at::Tensor output_s, + int64_t group_size, + double eps, + double int8_min, + double int8_max); void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static); void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s); void cublas_grouped_gemm( diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 36160f0e9..fbc4ac675 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -31,6 +31,7 @@ from sgl_kernel.gemm import ( int8_scaled_mm, sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_fp8, + sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, ) from sgl_kernel.moe import moe_align_block_size, topk_softmax diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 65bb3fff5..bab9e3c8c 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -96,6 +96,20 @@ def sgl_per_token_group_quant_fp8( ) +def sgl_per_token_group_quant_int8( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + group_size: int, + eps: float, + int8_min: float, + int8_max: float, +) -> None: + torch.ops.sgl_kernel.sgl_per_token_group_quant_int8( + input, output_q, output_s, group_size, eps, int8_min, int8_max + ) + + def sgl_per_tensor_quant_fp8( input: torch.Tensor, output_q: torch.Tensor, diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 4805e29f7..a7b612220 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -153,7 +153,7 @@ sources = [ "csrc/gemm/fp8_gemm_kernel.cu", "csrc/gemm/fp8_blockwise_gemm_kernel.cu", "csrc/gemm/int8_gemm_kernel.cu", - "csrc/gemm/per_token_group_quant_fp8.cu", + "csrc/gemm/per_token_group_quant_8bit.cu", "csrc/gemm/per_token_quant_fp8.cu", "csrc/gemm/per_tensor_quant_fp8.cu", "csrc/moe/moe_align_kernel.cu", diff --git a/sgl-kernel/tests/test_per_token_group_quant_fp8.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py similarity index 65% rename from sgl-kernel/tests/test_per_token_group_quant_fp8.py rename to sgl-kernel/tests/test_per_token_group_quant_8bit.py index 9fa7c9bd1..b628a6a42 100644 --- a/sgl-kernel/tests/test_per_token_group_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py @@ -1,20 +1,20 @@ import itertools -from typing import Any, Dict, List, Optional, Tuple +from typing import Tuple import pytest import torch import triton import triton.language as tl -from sgl_kernel import sgl_per_token_group_quant_fp8 +from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8 -from sglang.srt.utils import get_device_core_count, get_device_name, is_hip +from sglang.srt.utils import is_hip is_hip_ = is_hip() fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn @triton.jit -def _per_token_group_quant_fp8( +def _per_token_group_quant_8bit( # Pointers to inputs and output y_ptr, y_q_ptr, @@ -25,16 +25,15 @@ def _per_token_group_quant_fp8( N, # Avoid to divide zero eps, - # Information for float8 - fp8_min, - fp8_max, + # Information for 8bit data type (int8 or fp8_type_) + max_8bit, + min_8bit, # Meta-parameters BLOCK: tl.constexpr, ): """A Triton-accelerated function to perform per-token-group quantization on a tensor. - - This function converts the tensor values into float8 values. + This function converts the tensor values into 8bit values. """ # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) @@ -48,30 +47,27 @@ def _per_token_group_quant_fp8( y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max - y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + y_s = _absmax / max_8bit + y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) -def triton_per_token_group_quant_fp8( +def triton_per_token_group_quant_8bit( x: torch.Tensor, group_size: int, + dst_dtype: torch.dtype, eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. - It converts the tensor values into signed float8 values and returns the quantized tensor along with the scaling factor used for quantization. - Args: x: The input tenosr with ndim >= 2. group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. - Returns: Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ @@ -80,12 +76,16 @@ def triton_per_token_group_quant_fp8( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - finfo = torch.finfo(dtype) - fp8_max = finfo.max + if dst_dtype == torch.int8: + iinfo = torch.iinfo(dst_dtype) + max_8bit = iinfo.max + min_8bit = iinfo.min + else: + finfo = torch.finfo(dst_dtype) + max_8bit = finfo.max + min_8bit = finfo.min - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) M = x.numel() // group_size N = group_size x_s = torch.empty( @@ -98,15 +98,15 @@ def triton_per_token_group_quant_fp8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_fp8[(M,)]( + _per_token_group_quant_8bit[(M,)]( x, x_q, x_s, group_size, N, eps, - fp8_min=fp8_min, - fp8_max=fp8_max, + max_8bit, + min_8bit, BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, @@ -115,53 +115,58 @@ def triton_per_token_group_quant_fp8( return x_q, x_s -def sglang_per_token_group_quant_fp8( +def sglang_per_token_group_quant_8bit( x: torch.Tensor, group_size: int, + dst_dtype: torch.dtype, eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, ): assert ( x.shape[-1] % group_size == 0 ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size + x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) x_s = torch.empty( x.shape[:-1] + (x.shape[-1] // group_size,), device=x.device, dtype=torch.float32, ) - sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) + if dst_dtype == torch.int8: + iinfo = torch.iinfo(dst_dtype) + int8_max = iinfo.max + int8_min = iinfo.min + sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) + else: + f8_info = torch.finfo(dst_dtype) + fp8_max = f8_info.max + fp8_min = f8_info.min + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) return x_q, x_s @pytest.mark.parametrize( - "batch_size, seq_len, group_size", + "batch_size, seq_len, group_size, dst_dtype", list( itertools.product( [1, 2, 4, 8, 16, 32, 64, 128], # batch_size [64, 128, 256, 512, 1024, 2048], # seq_len [16, 32, 64, 128, 256], # group_size + [torch.int8, fp8_type_], # dtype ) ), ) -def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size): +def test_per_token_group_quant_compare_implementations( + batch_size, seq_len, group_size, dst_dtype +): x = torch.randn( (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16 ) - x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size) - x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size) + x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype) assert torch.allclose( x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5