diff --git a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu index 25b57c8f4..d818ddfb8 100644 --- a/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu +++ b/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -32,7 +33,7 @@ __global__ void per_token_group_quant_8bit_kernel( const float eps, const float min_8bit, const float max_8bit, - const int scale_num_rows = 0, + const int num_groups_per_row = 0, const int scale_stride = 0) { const int threads_per_group = 16; const int64_t local_group_id = threadIdx.x / threads_per_group; @@ -53,11 +54,10 @@ __global__ void per_token_group_quant_8bit_kernel( if constexpr (IS_COLUMN_MAJOR) { const int num_elems_per_pack = static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); - const int scale_num_rows_element = scale_num_rows * num_elems_per_pack; - const int row_idx = global_group_id / scale_num_rows_element; - const int col_idx_raw = global_group_id % scale_num_rows_element; - const int col_idx = col_idx_raw / num_elems_per_pack; - const int pack_idx = col_idx_raw % num_elems_per_pack; + const int row_idx = global_group_id / num_groups_per_row; + const int col_idx_unpacked = global_group_id % num_groups_per_row; + const int col_idx = col_idx_unpacked / num_elems_per_pack; + const int pack_idx = col_idx_unpacked % num_elems_per_pack; scale_output = reinterpret_cast(output_s) + (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx); } else { @@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel( float y_s = local_absmax / max_8bit; if constexpr (SCALE_UE8M0) { - y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f)))); } // TODO can optimize @@ -152,7 +152,8 @@ void sgl_per_token_group_quant_8bit( const int num_threads = groups_per_block * THREADS_PER_GROUP; const bool is_column_major = output_s.stride(0) < output_s.stride(1); - const int scale_num_rows = output_s.size(1); + const int hidden_dim = input.size(input.dim() - 1); + const int num_groups_per_row = hidden_dim / group_size; const int scale_stride = output_s.stride(1); #define LAUNCH_KERNEL(T, DST_DTYPE) \ @@ -171,7 +172,7 @@ void sgl_per_token_group_quant_8bit( (float)eps, \ (float)min_8bit, \ (float)max_8bit, \ - scale_num_rows, \ + num_groups_per_row, \ scale_stride); \ } else { \ per_token_group_quant_8bit_kernel<<>>( \ @@ -184,7 +185,7 @@ void sgl_per_token_group_quant_8bit( (float)eps, \ (float)min_8bit, \ (float)max_8bit, \ - scale_num_rows, \ + num_groups_per_row, \ scale_stride); \ } \ } else { \ @@ -207,7 +208,7 @@ void sgl_per_token_group_quant_8bit( LAUNCH_KERNEL(scalar_t, int8_t); return true; } else if (dst_type == at::ScalarType::Float8_e4m3fn) { - LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn); + LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3); return true; } return false;