Support new DeepGEMM format in per token group quant (#7146)
@@ -116,7 +116,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   m.def(
       "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);

   m.def(
@@ -16,11 +16,16 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }

-template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
+template <
+    typename T,
+    typename DST_DTYPE,
+    bool IS_COLUMN_MAJOR = false,
+    bool SCALE_UE8M0 = false,
+    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
 __global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
-    float* __restrict__ output_s,
+    scale_packed_t* __restrict__ output_s,
     const int group_size,
     const int num_groups,
     const int groups_per_block,
@@ -39,15 +44,24 @@ __global__ void per_token_group_quant_8bit_kernel(

   float local_absmax = eps;

+  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
+  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
+
   const T* group_input = input + block_group_offset;
   DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
-  float* scale_output;
+  scale_element_t* scale_output;

   if constexpr (IS_COLUMN_MAJOR) {
-    const int row_idx = global_group_id / scale_num_rows;
-    const int col_idx = global_group_id % scale_num_rows;
-    scale_output = output_s + (col_idx * scale_stride + row_idx);
+    const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
+    const int row_idx = global_group_id / scale_num_rows_element;
+    const int col_idx_raw = global_group_id % scale_num_rows_element;
+    const int col_idx = col_idx_raw / num_elems_per_pack;
+    const int pack_idx = col_idx_raw % num_elems_per_pack;
+    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+                   (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
   } else {
+    static_assert(!SCALE_UE8M0);
     scale_output = output_s + global_group_id;
   }

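
Note on the column-major indexing above: with SCALE_UE8M0 the scale tensor is addressed byte-wise, and four consecutive group scales along a row share one 32-bit word. Below is a minimal standalone sketch of that mapping (illustration only, not code from this PR; the scale_num_rows / scale_stride values are hypothetical).

#include <cstdint>
#include <cstdio>

int main() {
  // Mirrors the kernel's packed indexing: scale_packed_t = uint32_t, scale_element_t = uint8_t.
  const int num_elems_per_pack = sizeof(uint32_t) / sizeof(uint8_t);       // 4 scales per word
  const int scale_num_rows = 2;                                            // packed words per row (hypothetical)
  const int scale_stride = 8;                                              // stride of output_s dim 1, in words (hypothetical)
  const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;  // unpacked scales per row

  for (int global_group_id = 0; global_group_id < 8; ++global_group_id) {
    const int row_idx = global_group_id / scale_num_rows_element;      // row in the unpacked scale layout
    const int col_idx_raw = global_group_id % scale_num_rows_element;  // column before packing
    const int col_idx = col_idx_raw / num_elems_per_pack;              // which packed word
    const int pack_idx = col_idx_raw % num_elems_per_pack;             // byte lane inside the word
    const int byte_offset =
        col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx;
    std::printf("group %d -> word %d, lane %d, byte offset %d\n", global_group_id, col_idx, pack_idx, byte_offset);
  }
  return 0;
}
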
@@ -70,10 +84,21 @@ __global__ void per_token_group_quant_8bit_kernel(

   local_absmax = GroupReduceMax(local_absmax, lane_id);

-  const float y_s = local_absmax / max_8bit;
+  float y_s = local_absmax / max_8bit;
+  if constexpr (SCALE_UE8M0) {
+    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+  }
+
+  // TODO can optimize
+  scale_element_t y_s_quant;
+  if constexpr (SCALE_UE8M0) {
+    y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
+  } else {
+    y_s_quant = y_s;
+  }

   if (lane_id == 0) {
-    *scale_output = y_s;
+    *scale_output = y_s_quant;
   }

   for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
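
The SCALE_UE8M0 branch above rounds each group scale up to the next power of two and stores only a biased exponent (bias 127), one byte per group, which is the scale format targeted by this PR's DeepGEMM support. Below is a minimal host-side sketch of the encode/decode round trip, assuming the same rounding as the kernel (illustration only, not part of this diff).

#include <cmath>
#include <cstdint>
#include <cstdio>

// Encode a positive scale as a UE8M0 byte: round up to the next power of two,
// then store the exponent with a bias of 127 (mirrors the kernel's SCALE_UE8M0 path).
static uint8_t encode_ue8m0(float y_s) {
  const float pow2 = std::exp2(std::ceil(std::log2(std::fmax(std::fabs(y_s), 1e-10f))));
  return static_cast<uint8_t>(static_cast<int>(std::log2(pow2)) + 127);
}

// Decode the byte back to a float scale: 2^(e - 127).
static float decode_ue8m0(uint8_t e) {
  return std::exp2(static_cast<float>(e) - 127.0f);
}

int main() {
  const float scale = 0.023f;             // example per-group scale
  const uint8_t e = encode_ue8m0(scale);  // 0.023 rounds up to 2^-5 = 0.03125, stored as 122
  std::printf("scale=%f -> byte=%d -> decoded=%f\n", scale, (int)e, decode_ue8m0(e));
  return 0;
}
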
@@ -96,7 +121,8 @@ void sgl_per_token_group_quant_8bit(
     int64_t group_size,
     double eps,
     double min_8bit,
-    double max_8bit) {
+    double max_8bit,
+    bool scale_ue8m0 = false) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);

@@ -129,35 +155,51 @@ void sgl_per_token_group_quant_8bit(
   const int scale_num_rows = output_s.size(1);
   const int scale_stride = output_s.stride(1);

 #define LAUNCH_KERNEL(T, DST_DTYPE) \
   do { \
     dim3 grid(num_blocks); \
     dim3 block(num_threads); \
     if (is_column_major) { \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
-          static_cast<T*>(input.data_ptr()), \
-          output_q.data_ptr(), \
-          static_cast<float*>(output_s.data_ptr()), \
-          group_size, \
-          num_groups, \
-          groups_per_block, \
-          (float)eps, \
-          (float)min_8bit, \
-          (float)max_8bit, \
-          scale_num_rows, \
-          scale_stride); \
-    } else { \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
-          static_cast<T*>(input.data_ptr()), \
-          output_q.data_ptr(), \
-          static_cast<float*>(output_s.data_ptr()), \
-          group_size, \
-          num_groups, \
-          groups_per_block, \
-          (float)eps, \
-          (float)min_8bit, \
-          (float)max_8bit); \
-    } \
+      if (scale_ue8m0) { \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>( \
+            static_cast<T*>(input.data_ptr()), \
+            output_q.data_ptr(), \
+            static_cast<uint32_t*>(output_s.data_ptr()), \
+            group_size, \
+            num_groups, \
+            groups_per_block, \
+            (float)eps, \
+            (float)min_8bit, \
+            (float)max_8bit, \
+            scale_num_rows, \
+            scale_stride); \
+      } else { \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
+            static_cast<T*>(input.data_ptr()), \
+            output_q.data_ptr(), \
+            static_cast<float*>(output_s.data_ptr()), \
+            group_size, \
+            num_groups, \
+            groups_per_block, \
+            (float)eps, \
+            (float)min_8bit, \
+            (float)max_8bit, \
+            scale_num_rows, \
+            scale_stride); \
+      } \
+    } else { \
+      assert(!scale_ue8m0); \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
+          static_cast<T*>(input.data_ptr()), \
+          output_q.data_ptr(), \
+          static_cast<float*>(output_s.data_ptr()), \
+          group_size, \
+          num_groups, \
+          groups_per_block, \
+          (float)eps, \
+          (float)min_8bit, \
+          (float)max_8bit); \
+    } \
   } while (0)

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
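
Host-side note: when scale_ue8m0 is set, the launch path above casts output_s.data_ptr() to uint32_t*, so the caller is expected to provide a 32-bit-backed scale tensor with four UE8M0 bytes packed per element. A rough sizing sketch under that assumption (hypothetical sizes, not code from this diff):

#include <cstdio>

int main() {
  const int k = 4096;          // hidden size (hypothetical)
  const int group_size = 128;  // quantization group size (hypothetical)
  const int groups_per_token = k / group_size;                    // 32 per-group scales per token
  const int packed_words_per_token = (groups_per_token + 3) / 4;  // 8 packed 32-bit words per token
  std::printf("%d groups -> %d packed 32-bit words per token\n", groups_per_token, packed_words_per_token);
  return 0;
}
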
@@ -192,6 +234,7 @@ void sgl_per_token_group_quant_fp8(
     int64_t group_size,
     double eps,
     double fp8_min,
-    double fp8_max) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+    double fp8_max,
+    bool scale_ue8m0) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
 }
@@ -175,7 +175,8 @@ void sgl_per_token_group_quant_fp8(
     int64_t group_size,
     double eps,
     double fp8_min,
-    double fp8_max);
+    double fp8_max,
+    bool scale_ue8m0);
 void sgl_per_token_group_quant_int8(
     at::Tensor input,
     at::Tensor output_q,
@@ -90,9 +90,10 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
+    scale_ue8m0: bool,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
     )

@@ -255,7 +255,10 @@ def sglang_per_token_group_quant_8bit(
         f8_info = torch.finfo(dtype)
         fp8_max = f8_info.max
         fp8_min = f8_info.min
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+        scale_ue8m0 = False  # TODO also test true
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )

         return x_q, x_s
