[Quant Kernel] Refactor per-token-group fp8 quant kernel to support int8, up to 2x faster (#4396)

Author: Chunan Zeng
Date: 2025-03-23 23:44:17 -07:00
Committed by: GitHub
Parent: 3980ff1be6
Commit: 65c24c28f9
8 changed files with 191 additions and 127 deletions
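In short: the FP8-only kernel becomes a dtype-generic 8-bit kernel, the clamp range is passed in at runtime as [min_8bit, max_8bit], and thin wrappers keep the existing sgl_per_token_group_quant_fp8 entry point while adding an int8 twin. A minimal host-side C++ sketch of the quantization scheme the kernel implements (illustrative only; the [-448, 448] FP8-E4M3 and [-128, 127] INT8 bounds are the conventional values a caller would pass):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference per-token-group quantization: one scale per group of
// `group_size` contiguous values, scale = absmax / max_8bit.
void quant_group_ref(const float* x, int group_size, float eps,
                     float min_8bit, float max_8bit,
                     float& scale_out, std::vector<float>& q_out) {
  float absmax = eps;  // eps keeps the scale strictly positive
  for (int i = 0; i < group_size; ++i)
    absmax = std::max(absmax, std::fabs(x[i]));
  const float y_s = absmax / max_8bit;  // same formula as the kernel
  scale_out = y_s;
  q_out.resize(group_size);
  for (int i = 0; i < group_size; ++i)  // clamp into the 8-bit range
    q_out[i] = std::min(std::max(x[i] / y_s, min_8bit), max_8bit);
}
// INT8 callers would pass (-128.f, 127.f); FP8-E4M3 callers (-448.f, 448.f).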


@@ -6,8 +6,6 @@
 #include "utils.h"
 
-using FP8_TYPE = c10::Float8_e4m3fn;
-
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;
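Only GroupReduceMax's callers change in this commit; its body is elided by the diff. For reference, a 16-lane butterfly max over XOR shuffles is the standard shape for this reduction (a sketch assuming the usual implementation, not taken from this commit):

// Each cluster of 16 threads reduces to a shared maximum in four
// __shfl_xor_sync steps; afterwards every participating lane holds it.
__device__ __forceinline__ float group_reduce_max_sketch(float val) {
  const unsigned mask = 0xffff;  // the 16 participating lanes
  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}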
@@ -18,27 +16,28 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }
 
-template <typename T, int GROUPS_PER_BLOCK = 16>
-__global__ void per_token_group_quant_fp8_kernel(
+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
     float* __restrict__ output_s,
     const int group_size,
     const int num_groups,
+    const int groups_per_block,
     const float eps,
-    const float fp8_min,
-    const float fp8_max) {
+    const float min_8bit,
+    const float max_8bit) {
   const int threads_per_group = 16;
   const int local_group_id = threadIdx.x / threads_per_group;
   const int lane_id = threadIdx.x % threads_per_group;
-  const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK;
+  const int block_group_id = blockIdx.x * groups_per_block;
   const int block_group_offset = (block_group_id + local_group_id) * group_size;
 
   float local_absmax = eps;
   const T* group_input = input + block_group_offset;
-  FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + block_group_offset;
+  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
   float* scale_output = output_s + (block_group_id + local_group_id);
 
   constexpr uint32_t vec_size = 16 / sizeof(T);
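vec_size = 16 / sizeof(T) sizes each access at 16 bytes, so every lane issues 128-bit loads (8 halves, or 4 floats). The absmax loop itself falls in the elided region between this hunk and the next; a sketch of the presumed access pattern for T = __half (the helper name and exact loop shape are assumptions):

#include <cuda_fp16.h>

// Each of the 16 lanes in a group reads one 16-byte vector per iteration.
// Assumes group_input is 16-byte aligned and group_size is a multiple of 8.
__device__ void absmax_pass_sketch(const __half* group_input,
                                   uint32_t group_size, uint32_t lane_id,
                                   float& local_absmax) {
  constexpr uint32_t vec_size = 16 / sizeof(__half);  // = 8
  const uint32_t num_vecs = group_size / vec_size;
  for (uint32_t i = lane_id; i < num_vecs; i += 16) {
    uint4 packed = *reinterpret_cast<const uint4*>(group_input + i * vec_size);
    const __half* v = reinterpret_cast<const __half*>(&packed);
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j)
      local_absmax = fmaxf(local_absmax, fabsf(__half2float(v[j])));
  }
}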
@@ -60,7 +59,7 @@ __global__ void per_token_group_quant_fp8_kernel(
   local_absmax = GroupReduceMax(local_absmax, lane_id);
 
-  const float y_s = local_absmax / fp8_max;
+  const float y_s = local_absmax / max_8bit;
 
   if (lane_id == 0) {
     *scale_output = y_s;
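The scale y_s = local_absmax / max_8bit maps the group's largest magnitude exactly onto the top of the destination range: for example, absmax = 6.35 under int8 (max_8bit = 127) gives y_s = 0.05, so 6.35 / 0.05 clamps to exactly 127, while the same code with max_8bit = 448 reproduces the old FP8-E4M3 behavior.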
@@ -73,20 +72,20 @@
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
       float val = static_cast<float>(input_vec[j]);
-      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
-      group_output[i * vec_size + j] = FP8_TYPE(q_val);
+      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
+      group_output[i * vec_size + j] = DST_DTYPE(q_val);
     }
   }
 }
 
-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     torch::Tensor input,
     torch::Tensor output_q,
     torch::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max) {
+    double min_8bit,
+    double max_8bit) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
   CHECK_INPUT(output_s);
@@ -111,36 +110,58 @@ void sgl_per_token_group_quant_fp8(
     groups_per_block = 2;
   }
 
-#define LAUNCH_KERNEL(T, GPB)                                                          \
-  do {                                                                                 \
-    constexpr int GROUPS_PER_BLOCK = GPB;                                              \
-    dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK);                 \
-    dim3 block(GROUPS_PER_BLOCK * THREADS_PER_GROUP);                                  \
-    per_token_group_quant_fp8_kernel<T, GROUPS_PER_BLOCK><<<grid, block, 0, stream>>>( \
-        static_cast<T*>(input.data_ptr()),                                             \
-        output_q.data_ptr(),                                                           \
-        static_cast<float*>(output_s.data_ptr()),                                      \
-        group_size,                                                                    \
-        num_groups,                                                                    \
-        (float)eps,                                                                    \
-        (float)fp8_min,                                                                \
-        (float)fp8_max);                                                               \
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                              \
+  do {                                                                           \
+    dim3 grid(num_blocks);                                                       \
+    dim3 block(num_threads);                                                     \
+    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
+        static_cast<T*>(input.data_ptr()),                                       \
+        output_q.data_ptr(),                                                     \
+        static_cast<float*>(output_s.data_ptr()),                                \
+        group_size,                                                              \
+        num_groups,                                                              \
+        groups_per_block,                                                        \
+        (float)eps,                                                              \
+        (float)min_8bit,                                                         \
+        (float)max_8bit);                                                        \
   } while (0)
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    if (groups_per_block == 16) {
-      LAUNCH_KERNEL(scalar_t, 16);
-    } else if (groups_per_block == 8) {
-      LAUNCH_KERNEL(scalar_t, 8);
-    } else if (groups_per_block == 4) {
-      LAUNCH_KERNEL(scalar_t, 4);
-    } else if (groups_per_block == 2) {
-      LAUNCH_KERNEL(scalar_t, 2);
-    } else {
-      LAUNCH_KERNEL(scalar_t, 1);
+    if (dst_type == at::ScalarType::Char) {
+      LAUNCH_KERNEL(scalar_t, int8_t);
+      return true;
+    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
+      LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+      return true;
     }
-    return true;
+    return false;
   });
 
 #undef LAUNCH_KERNEL
 }
+
+void sgl_per_token_group_quant_int8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
+}
+
+void sgl_per_token_group_quant_fp8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double fp8_min,
+    double fp8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+}
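With groups_per_block demoted to a runtime kernel argument, the output dtype alone picks the template instantiation: two instantiations (int8_t, c10::Float8_e4m3fn) replace the five GROUPS_PER_BLOCK specializations, and the two thin wrappers keep both public entry points. An illustrative caller of the unified function (shapes, eps, and the range constants are assumptions, not from the commit):

#include <torch/torch.h>

// Declaration as introduced above (defined in the .cu file).
void sgl_per_token_group_quant_8bit(
    torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s,
    int64_t group_size, double eps, double min_8bit, double max_8bit);

void quantize_both_ways(const torch::Tensor& x) {  // [tokens, hidden], fp16, CUDA
  const int64_t group_size = 128;
  auto out_s = torch::empty({x.numel() / group_size},
                            x.options().dtype(torch::kFloat));

  // INT8 path: output_q's Char dtype selects the int8_t instantiation.
  auto q_int8 = torch::empty_like(x, x.options().dtype(torch::kChar));
  sgl_per_token_group_quant_8bit(x, q_int8, out_s, group_size,
                                 /*eps=*/1e-10, -128.0, 127.0);

  // FP8 path: a Float8_e4m3fn output selects the FP8 instantiation.
  auto q_fp8 = torch::empty_like(
      x, x.options().dtype(at::ScalarType::Float8_e4m3fn));
  sgl_per_token_group_quant_8bit(x, q_fp8, out_s, group_size,
                                 /*eps=*/1e-10, -448.0, 448.0);
}

The companion hunk below, from the op-registration translation unit, exposes the new int8 entry point to the dispatcher alongside the existing fp8 one.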


@@ -98,6 +98,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       " float eps, float fp8_min, float fp8_max) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
 
+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+
   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);