Minor improvement to per_tensor_quant_fp8 (#4197)
This commit is contained in:
@@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
|
||||
|
||||
template <typename T>
|
||||
__global__ void per_tensor_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
FP8_TYPE* __restrict__ output,
|
||||
const float* __restrict__ scale,
|
||||
const int64_t num_elements) {
|
||||
const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) {
|
||||
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int grid_size = blockDim.x * gridDim.x;
|
||||
const float scale_val = 1.0f / (*scale);
|
||||
|
||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||
@@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
|
||||
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
|
||||
}
|
||||
|
||||
float scale_val = 1.0f / (*static_cast<float*>(output_s.data_ptr()));
|
||||
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<scalar_t*>(input.data_ptr()),
|
||||
static_cast<FP8_TYPE*>(output_q.data_ptr()),
|
||||
static_cast<float*>(output_s.data_ptr()),
|
||||
num_elements);
|
||||
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()), scale_val, num_elements);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user