fix per token cuda kernel hidden dim cannot divide by 16 (#8543)
This commit is contained in:
@@ -75,14 +75,21 @@ __global__ void per_token_quant_fp8_kernel(
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE>
|
||||
template <typename T, typename DST_DTYPE, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
@@ -100,19 +107,17 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
|
||||
float max_value = 0.0f;
|
||||
|
||||
// We want to store 128 bits of data at a time. 16 = 128 / 8 bits
|
||||
// Load is already vectorized, so 16 elements work for T.
|
||||
const uint32_t VEC_SIZE = 16;
|
||||
using vec_t = flashinfer::vec_t<T, VEC_SIZE>;
|
||||
const int32_t num_vec_elems = hidden_dim / VEC_SIZE;
|
||||
// Use template parameter for vector size
|
||||
using vec_t = flashinfer::vec_t<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_dim / kVecSize;
|
||||
|
||||
// Find max using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(token_input + i * VEC_SIZE);
|
||||
input_vec.cast_load(token_input + i * kVecSize);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
@@ -132,11 +137,11 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
// Quantize using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(token_input + i * VEC_SIZE);
|
||||
input_vec.cast_load(token_input + i * kVecSize);
|
||||
|
||||
DST_DTYPE output_arr[VEC_SIZE];
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
#ifndef USE_ROCM
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
@@ -147,7 +152,14 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
*(uint4*)(token_output + i * VEC_SIZE) = *(uint4*)output_arr;
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,13 +170,14 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
|
||||
const auto input_sizes = input.sizes();
|
||||
const int64_t num_tokens = input_sizes[0];
|
||||
const int64_t hidden_dim = input_sizes[1];
|
||||
TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);
|
||||
TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// Hard-code sm_count
|
||||
int sm_count = 132;
|
||||
constexpr int TOKENS_PER_CTA = 8;
|
||||
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
|
||||
const bool use_vec16 = (hidden_dim % 16 == 0);
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
if (use_warp_kernel) {
|
||||
@@ -172,23 +185,43 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
|
||||
constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256
|
||||
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
|
||||
dim3 block(THREADS);
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input.data_ptr()),
|
||||
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input.data_ptr()),
|
||||
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input.data_ptr()),
|
||||
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
}
|
||||
} else {
|
||||
// -------- baseline -----------------------------------------------------
|
||||
constexpr int THREADS = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(THREADS);
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input.data_ptr()),
|
||||
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input.data_ptr()),
|
||||
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
|
||||
static_cast<const scalar_t*>(input.data_ptr()),
|
||||
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
hidden_dim,
|
||||
num_tokens);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user