Speed up per token and per tensor quant by 15% (#4639)
This commit is contained in:
@@ -54,42 +54,40 @@ __global__ void per_tensor_quant_fp8_kernel(
|
|||||||
const int grid_size = blockDim.x * gridDim.x;
|
const int grid_size = blockDim.x * gridDim.x;
|
||||||
const float scale_val = 1.0f / (*scale);
|
const float scale_val = 1.0f / (*scale);
|
||||||
|
|
||||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
// We want to store 128 bits of data at a time. 16 = 128 / 8 bits
|
||||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
// Load is already vectorized, so 16 elements work for T.
|
||||||
|
const uint32_t VEC_SIZE = 16;
|
||||||
|
using vec_t = flashinfer::vec_t<T, VEC_SIZE>;
|
||||||
|
|
||||||
const int32_t num_vec_elems = num_elements / vec_size;
|
const int32_t num_vec_elems = num_elements / VEC_SIZE;
|
||||||
|
|
||||||
for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
|
for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
|
||||||
vec_t input_vec;
|
vec_t input_vec;
|
||||||
input_vec.cast_load(input + i * vec_size);
|
input_vec.cast_load(input + i * VEC_SIZE);
|
||||||
|
|
||||||
FP8_TYPE output_arr[vec_size];
|
FP8_TYPE output_arr[VEC_SIZE];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
output_arr[j] = static_cast<FP8_TYPE>(val);
|
output_arr[j] = static_cast<FP8_TYPE>(val);
|
||||||
#else
|
#else
|
||||||
output_arr[j] = c10::Float8_e4m3fnuz(
|
output_arr[j] = c10::Float8_e4m3fnuz(
|
||||||
__hip_cvt_float_to_fp8(value, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||||
c10::Float8_e4m3fnuz::from_bits());
|
c10::Float8_e4m3fnuz::from_bits());
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
*(uint4*)(output + i * VEC_SIZE) = *(uint4*)output_arr;
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
|
||||||
output[i * vec_size + j] = output_arr[j];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int32_t remaining_start = num_vec_elems * vec_size;
|
const int32_t remaining_start = num_vec_elems * VEC_SIZE;
|
||||||
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
|
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
|
||||||
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
|
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
output[idx] = static_cast<FP8_TYPE>(val);
|
output[idx] = static_cast<FP8_TYPE>(val);
|
||||||
#else
|
#else
|
||||||
output[idx] = c10::Float8_e4m3fnuz(
|
output[idx] = c10::Float8_e4m3fnuz(
|
||||||
__hip_cvt_float_to_fp8(value, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
|
||||||
c10::Float8_e4m3fnuz::from_bits());
|
c10::Float8_e4m3fnuz::from_bits());
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,17 +24,19 @@ __global__ void per_token_quant_fp8_kernel(
|
|||||||
|
|
||||||
float max_value = 0.0f;
|
float max_value = 0.0f;
|
||||||
|
|
||||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
// We want to store 128 bits of data at a time. 16 = 128 / 8 bits
|
||||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
// Load is already vectorized, so 16 elements work for T.
|
||||||
const int32_t num_vec_elems = hidden_dim / vec_size;
|
const uint32_t VEC_SIZE = 16;
|
||||||
|
using vec_t = flashinfer::vec_t<T, VEC_SIZE>;
|
||||||
|
const int32_t num_vec_elems = hidden_dim / VEC_SIZE;
|
||||||
|
|
||||||
// Find max using vectorized loads
|
// Find max using vectorized loads
|
||||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||||
vec_t input_vec;
|
vec_t input_vec;
|
||||||
input_vec.cast_load(token_input + i * vec_size);
|
input_vec.cast_load(token_input + i * VEC_SIZE);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
float val = static_cast<float>(input_vec[j]);
|
float val = static_cast<float>(input_vec[j]);
|
||||||
max_value = fmaxf(max_value, fabsf(val));
|
max_value = fmaxf(max_value, fabsf(val));
|
||||||
}
|
}
|
||||||
@@ -42,24 +44,24 @@ __global__ void per_token_quant_fp8_kernel(
|
|||||||
|
|
||||||
max_value = blockReduceMax(max_value);
|
max_value = blockReduceMax(max_value);
|
||||||
|
|
||||||
__shared__ float block_max;
|
__shared__ float scale;
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
block_max = max_value / FP8_E4M3_MAX;
|
scale = max_value / FP8_E4M3_MAX;
|
||||||
output_s[token_idx] = block_max;
|
output_s[token_idx] = scale;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const float scale_val = 1.0f / block_max;
|
const float scale_inv = 1.0f / scale;
|
||||||
|
|
||||||
// Quantize using vectorized loads
|
// Quantize using vectorized loads
|
||||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||||
vec_t input_vec;
|
vec_t input_vec;
|
||||||
input_vec.cast_load(token_input + i * vec_size);
|
input_vec.cast_load(token_input + i * VEC_SIZE);
|
||||||
|
|
||||||
FP8_TYPE output_arr[vec_size];
|
FP8_TYPE output_arr[VEC_SIZE];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
output_arr[j] = static_cast<FP8_TYPE>(val);
|
output_arr[j] = static_cast<FP8_TYPE>(val);
|
||||||
#else
|
#else
|
||||||
@@ -69,10 +71,7 @@ __global__ void per_token_quant_fp8_kernel(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
*(uint4*)(token_output + i * VEC_SIZE) = *(uint4*)output_arr;
|
||||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
|
||||||
token_output[i * vec_size + j] = output_arr[j];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +84,7 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
|
|||||||
const int64_t num_tokens = input_sizes[0];
|
const int64_t num_tokens = input_sizes[0];
|
||||||
const int64_t hidden_dim = input_sizes[1];
|
const int64_t hidden_dim = input_sizes[1];
|
||||||
|
|
||||||
TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
|
TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);
|
||||||
|
|
||||||
const int block_size = 256;
|
const int block_size = 256;
|
||||||
const int num_blocks = num_tokens;
|
const int num_blocks = num_tokens;
|
||||||
|
|||||||
Reference in New Issue
Block a user