Revert "Minor improvement to per_tensor_quant_fp8 (#4197)" (#4198)

This commit is contained in:
Yineng Zhang
2025-03-07 22:57:09 -08:00
committed by GitHub
parent 90bb2be27e
commit 96d0e37fa7

View File

@@ -57,9 +57,13 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) {
const T* __restrict__ input,
FP8_TYPE* __restrict__ output,
const float* __restrict__ scale,
const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int grid_size = blockDim.x * gridDim.x;
const float scale_val = 1.0f / (*scale);
constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>;
@@ -121,9 +125,12 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
}
float scale_val = 1.0f / (*static_cast<float*>(output_s.data_ptr()));
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()), scale_val, num_elements);
static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
num_elements);
return true;
});
}