Support new DeepGEMM format in per token group quant (#7146)
This commit is contained in:
@@ -116,7 +116,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max) -> ()");
|
||||
" float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
|
||||
|
||||
m.def(
|
||||
|
||||
@@ -16,11 +16,16 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
|
||||
template <
|
||||
typename T,
|
||||
typename DST_DTYPE,
|
||||
bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false,
|
||||
typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
const T* __restrict__ input,
|
||||
void* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
scale_packed_t* __restrict__ output_s,
|
||||
const int group_size,
|
||||
const int num_groups,
|
||||
const int groups_per_block,
|
||||
@@ -39,15 +44,24 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
float* scale_output;
|
||||
scale_element_t* scale_output;
|
||||
|
||||
if constexpr (IS_COLUMN_MAJOR) {
|
||||
const int row_idx = global_group_id / scale_num_rows;
|
||||
const int col_idx = global_group_id % scale_num_rows;
|
||||
scale_output = output_s + (col_idx * scale_stride + row_idx);
|
||||
const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
|
||||
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
|
||||
const int row_idx = global_group_id / scale_num_rows_element;
|
||||
const int col_idx_raw = global_group_id % scale_num_rows_element;
|
||||
const int col_idx = col_idx_raw / num_elems_per_pack;
|
||||
const int pack_idx = col_idx_raw % num_elems_per_pack;
|
||||
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
|
||||
(col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
|
||||
} else {
|
||||
static_assert(!SCALE_UE8M0);
|
||||
scale_output = output_s + global_group_id;
|
||||
}
|
||||
|
||||
@@ -70,10 +84,21 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax, lane_id);
|
||||
|
||||
const float y_s = local_absmax / max_8bit;
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
// TODO can optimize
|
||||
scale_element_t y_s_quant;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
|
||||
} else {
|
||||
y_s_quant = y_s;
|
||||
}
|
||||
|
||||
if (lane_id == 0) {
|
||||
*scale_output = y_s;
|
||||
*scale_output = y_s_quant;
|
||||
}
|
||||
|
||||
for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
|
||||
@@ -96,7 +121,8 @@ void sgl_per_token_group_quant_8bit(
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double min_8bit,
|
||||
double max_8bit) {
|
||||
double max_8bit,
|
||||
bool scale_ue8m0 = false) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(output_q);
|
||||
|
||||
@@ -129,35 +155,51 @@ void sgl_per_token_group_quant_8bit(
|
||||
const int scale_num_rows = output_s.size(1);
|
||||
const int scale_stride = output_s.stride(1);
|
||||
|
||||
#define LAUNCH_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
if (is_column_major) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit, \
|
||||
scale_num_rows, \
|
||||
scale_stride); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} \
|
||||
#define LAUNCH_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
if (is_column_major) { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<uint32_t*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit, \
|
||||
scale_num_rows, \
|
||||
scale_stride); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit, \
|
||||
scale_num_rows, \
|
||||
scale_stride); \
|
||||
} \
|
||||
} else { \
|
||||
assert(!scale_ue8m0); \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
@@ -192,6 +234,7 @@ void sgl_per_token_group_quant_fp8(
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double fp8_min,
|
||||
double fp8_max) {
|
||||
sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
|
||||
double fp8_max,
|
||||
bool scale_ue8m0) {
|
||||
sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
||||
|
||||
@@ -175,7 +175,8 @@ void sgl_per_token_group_quant_fp8(
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double fp8_min,
|
||||
double fp8_max);
|
||||
double fp8_max,
|
||||
bool scale_ue8m0);
|
||||
void sgl_per_token_group_quant_int8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
|
||||
@@ -90,9 +90,10 @@ def sgl_per_token_group_quant_fp8(
|
||||
eps: float,
|
||||
fp8_min: float,
|
||||
fp8_max: float,
|
||||
scale_ue8m0: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -255,7 +255,10 @@ def sglang_per_token_group_quant_8bit(
|
||||
f8_info = torch.finfo(dtype)
|
||||
fp8_max = f8_info.max
|
||||
fp8_min = f8_info.min
|
||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||
scale_ue8m0 = False # TODO also test true
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
Reference in New Issue
Block a user