From 90bb2be27e498be472af40f5ace8b2d9cd817d1d Mon Sep 17 00:00:00 2001 From: Rex Date: Fri, 7 Mar 2025 22:52:12 -0800 Subject: [PATCH] Minor improvement to per_tensor_quant_fp8 (#4197) --- .../sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index ea222c001..f1f7d14a9 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output template __global__ void per_tensor_quant_fp8_kernel( - const T* __restrict__ input, - FP8_TYPE* __restrict__ output, - const float* __restrict__ scale, - const int64_t num_elements) { + const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int grid_size = blockDim.x * gridDim.x; - const float scale_val = 1.0f / (*scale); constexpr uint32_t vec_size = 16 / sizeof(T); using vec_t = flashinfer::vec_t; @@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch per_tensor_absmax_kernel<<>>( static_cast(input.data_ptr()), static_cast(output_s.data_ptr()), num_elements); } - + float scale_val = 1.0f / (*static_cast(output_s.data_ptr())); per_tensor_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), - static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), - num_elements); + static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), scale_val, num_elements); return true; }); }