From 96d0e37fa7621c37a130ec12f867c8f99c9ef878 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 7 Mar 2025 22:57:09 -0800 Subject: [PATCH] Revert "Minor improvement to per_tensor_quant_fp8 (#4197)" (#4198) --- .../sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index f1f7d14a9..ea222c001 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -57,9 +57,13 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output template __global__ void per_tensor_quant_fp8_kernel( - const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) { + const T* __restrict__ input, + FP8_TYPE* __restrict__ output, + const float* __restrict__ scale, + const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int grid_size = blockDim.x * gridDim.x; + const float scale_val = 1.0f / (*scale); constexpr uint32_t vec_size = 16 / sizeof(T); using vec_t = flashinfer::vec_t; @@ -121,9 +125,12 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch per_tensor_absmax_kernel<<>>( static_cast(input.data_ptr()), static_cast(output_s.data_ptr()), num_elements); } - float scale_val = 1.0f / (*static_cast(output_s.data_ptr())); + per_tensor_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), scale_val, num_elements); + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + num_elements); return true; }); }