diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index d8e9fb336..9f3c2be9c 100755
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -116,7 +116,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
 
   m.def(
diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
index b374fd3e2..c9474b96e 100644
--- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
+++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
@@ -16,11 +16,16 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }
 
-template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
+template <
+    typename T,
+    typename DST_DTYPE,
+    bool IS_COLUMN_MAJOR = false,
+    bool SCALE_UE8M0 = false,
+    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
 __global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
-    float* __restrict__ output_s,
+    scale_packed_t* __restrict__ output_s,
     const int group_size,
     const int num_groups,
     const int groups_per_block,
@@ -39,15 +44,24 @@ __global__ void per_token_group_quant_8bit_kernel(
 
   float local_absmax = eps;
 
+  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
+  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
+
   const T* group_input = input + block_group_offset;
   DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
-  float* scale_output;
+  scale_element_t* scale_output;
 
   if constexpr (IS_COLUMN_MAJOR) {
-    const int row_idx = global_group_id / scale_num_rows;
-    const int col_idx = global_group_id % scale_num_rows;
-    scale_output = output_s + (col_idx * scale_stride + row_idx);
+    const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
+    const int row_idx = global_group_id / scale_num_rows_element;
+    const int col_idx_raw = global_group_id % scale_num_rows_element;
+    const int col_idx = col_idx_raw / num_elems_per_pack;
+    const int pack_idx = col_idx_raw % num_elems_per_pack;
+    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+                   (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
   } else {
+    static_assert(!SCALE_UE8M0);
     scale_output = output_s + global_group_id;
   }
 
@@ -70,10 +84,21 @@ __global__ void per_token_group_quant_8bit_kernel(
 
   local_absmax = GroupReduceMax(local_absmax, lane_id);
 
-  const float y_s = local_absmax / max_8bit;
+  float y_s = local_absmax / max_8bit;
+  if constexpr (SCALE_UE8M0) {
+    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+  }
+
+  // TODO can optimize
+  scale_element_t y_s_quant;
+  if constexpr (SCALE_UE8M0) {
+    y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
+  } else {
+    y_s_quant = y_s;
+  }
 
   if (lane_id == 0) {
-    *scale_output = y_s;
+    *scale_output = y_s_quant;
   }
 
   for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
@@ -96,7 +121,8 @@ void sgl_per_token_group_quant_8bit(
     int64_t group_size,
     double eps,
     double min_8bit,
-    double max_8bit) {
+    double max_8bit,
+    bool scale_ue8m0 = false) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
 
@@ -129,35 +155,51 @@ void sgl_per_token_group_quant_8bit(
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);
 
-#define LAUNCH_KERNEL(T, DST_DTYPE)                                                      \
-  do {                                                                                   \
-    dim3 grid(num_blocks);                                                               \
-    dim3 block(num_threads);                                                             \
-    if (is_column_major) {                                                               \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
-          static_cast<T*>(input.data_ptr()),                                             \
-          output_q.data_ptr(),                                                           \
-          static_cast<float*>(output_s.data_ptr()),                                      \
-          group_size,                                                                    \
-          num_groups,                                                                    \
-          groups_per_block,                                                              \
-          (float)eps,                                                                    \
-          (float)min_8bit,                                                               \
-          (float)max_8bit,                                                               \
-          scale_num_rows,                                                                \
-          scale_stride);                                                                 \
-    } else {                                                                             \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>(       \
-          static_cast<T*>(input.data_ptr()),                                             \
-          output_q.data_ptr(),                                                           \
-          static_cast<float*>(output_s.data_ptr()),                                      \
-          group_size,                                                                    \
-          num_groups,                                                                    \
-          groups_per_block,                                                              \
-          (float)eps,                                                                    \
-          (float)min_8bit,                                                               \
-          (float)max_8bit);                                                              \
-    }                                                                                    \
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                                                \
+  do {                                                                                             \
+    dim3 grid(num_blocks);                                                                         \
+    dim3 block(num_threads);                                                                       \
+    if (is_column_major) {                                                                         \
+      if (scale_ue8m0) {                                                                           \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>(   \
+            static_cast<T*>(input.data_ptr()),                                                     \
+            output_q.data_ptr(),                                                                   \
+            static_cast<uint32_t*>(output_s.data_ptr()),                                           \
+            group_size,                                                                            \
+            num_groups,                                                                            \
+            groups_per_block,                                                                      \
+            (float)eps,                                                                            \
+            (float)min_8bit,                                                                       \
+            (float)max_8bit,                                                                       \
+            scale_num_rows,                                                                        \
+            scale_stride);                                                                         \
+      } else {                                                                                     \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>(  \
+            static_cast<T*>(input.data_ptr()),                                                     \
+            output_q.data_ptr(),                                                                   \
+            static_cast<float*>(output_s.data_ptr()),                                              \
+            group_size,                                                                            \
+            num_groups,                                                                            \
+            groups_per_block,                                                                      \
+            (float)eps,                                                                            \
+            (float)min_8bit,                                                                       \
+            (float)max_8bit,                                                                       \
+            scale_num_rows,                                                                        \
+            scale_stride);                                                                         \
+      }                                                                                            \
+    } else {                                                                                       \
+      assert(!scale_ue8m0);                                                                        \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>(                 \
+          static_cast<T*>(input.data_ptr()),                                                       \
+          output_q.data_ptr(),                                                                     \
+          static_cast<float*>(output_s.data_ptr()),                                                \
+          group_size,                                                                              \
+          num_groups,                                                                              \
+          groups_per_block,                                                                        \
+          (float)eps,                                                                              \
+          (float)min_8bit,                                                                         \
+          (float)max_8bit);                                                                        \
+    }                                                                                              \
   } while (0)
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
@@ -192,6 +234,7 @@ void sgl_per_token_group_quant_fp8(
     int64_t group_size,
     double eps,
     double fp8_min,
-    double fp8_max) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+    double fp8_max,
+    bool scale_ue8m0) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
 }
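Note: the UE8M0 path rounds each group's scale *up* to the nearest power of two and stores only the biased exponent, one byte per group, four of which pack into each uint32 element of `output_s`. A minimal PyTorch sketch of the per-group scale math the kernel performs; the function name is illustrative and not part of this diff:

```python
import torch

def ref_ue8m0_group_scale(group: torch.Tensor, fp8_max: float, eps: float):
    # Group absmax, clamped from below by eps (the kernel seeds local_absmax with eps).
    absmax = group.float().abs().amax().clamp_min(eps)
    y_s = absmax / fp8_max
    # Round the scale up to the next power of two, mirroring
    # exp2f(ceilf(log2f(fmaxf(..., 1e-10f)))) in the kernel.
    y_s = torch.exp2(torch.ceil(torch.log2(y_s.clamp_min(1e-10))))
    # UE8M0 storage: the biased exponent alone, (int)log2(y_s) + 127, as one uint8.
    y_s_quant = (torch.log2(y_s).to(torch.int32) + 127).to(torch.uint8)
    return y_s, y_s_quant
```

Rounding up in log2 space keeps the quantized values within the fp8 range, at the cost of at most one extra bit of scale headroom per group.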
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 1fdfbeae1..1cc88afa0 100755
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -175,7 +175,8 @@ void sgl_per_token_group_quant_fp8(
     int64_t group_size,
     double eps,
     double fp8_min,
-    double fp8_max);
+    double fp8_max,
+    bool scale_ue8m0);
 void sgl_per_token_group_quant_int8(
     at::Tensor input,
     at::Tensor output_q,
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index 48a21ee8b..45a8f8134 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -90,9 +90,10 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
+    scale_ue8m0: bool,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
     )
 
 
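The Python wrapper gains `scale_ue8m0` with no default, so every existing call site must now pass it explicitly. A hedged usage sketch of the unpacked (float-scale) path; the import path, shapes, and eps value below are assumptions for illustration, not taken from this diff:

```python
import torch
from sgl_kernel import sgl_per_token_group_quant_fp8  # import path assumed

M, K, group_size = 4, 256, 128  # illustrative shapes
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty(M, K, device="cuda", dtype=torch.float8_e4m3fn)
# Row-major float scales, one per group. scale_ue8m0=True instead requires a
# column-major, uint32-packed x_s; the kernel asserts the pairing.
x_s = torch.empty(M, K // group_size, device="cuda", dtype=torch.float32)

info = torch.finfo(torch.float8_e4m3fn)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, info.min, info.max, False)
```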
diff --git a/sgl-kernel/tests/test_per_token_group_quant_8bit.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py
index 66be47d28..31070d1cd 100644
--- a/sgl-kernel/tests/test_per_token_group_quant_8bit.py
+++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py
@@ -255,7 +255,10 @@ def sglang_per_token_group_quant_8bit(
         f8_info = torch.finfo(dtype)
         fp8_max = f8_info.max
         fp8_min = f8_info.min
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+        scale_ue8m0 = False  # TODO also test true
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 
     return x_q, x_s
 
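When the test's TODO is picked up and `scale_ue8m0=True` is exercised, a reference will need to locate each group's exponent byte inside the packed, column-major scale buffer. This sketch mirrors the kernel's `IS_COLUMN_MAJOR` + `SCALE_UE8M0` index math; the helper name is hypothetical:

```python
def packed_scale_offset(global_group_id: int, scale_num_rows: int,
                        scale_stride: int, num_elems_per_pack: int = 4) -> int:
    # num_elems_per_pack = sizeof(uint32) / sizeof(uint8): four biased-exponent
    # bytes per packed uint32 element of output_s.
    scale_num_rows_element = scale_num_rows * num_elems_per_pack
    row_idx = global_group_id // scale_num_rows_element
    col_idx_raw = global_group_id % scale_num_rows_element
    col_idx = col_idx_raw // num_elems_per_pack
    pack_idx = col_idx_raw % num_elems_per_pack
    # Byte offset into output_s viewed as a flat uint8 buffer.
    return (col_idx * scale_stride * num_elems_per_pack
            + row_idx * num_elems_per_pack + pack_idx)
```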