diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
index f1f7d14a9..ea222c001 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
@@ -57,9 +57,13 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
 
 template <typename T>
 __global__ void per_tensor_quant_fp8_kernel(
-    const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) {
+    const T* __restrict__ input,
+    FP8_TYPE* __restrict__ output,
+    const float* __restrict__ scale,
+    const int64_t num_elements) {
   const int gid = blockIdx.x * blockDim.x + threadIdx.x;
   const int grid_size = blockDim.x * gridDim.x;
+  const float scale_val = 1.0f / (*scale);
 
   constexpr uint32_t vec_size = 16 / sizeof(T);
   using vec_t = flashinfer::vec_t<T, vec_size>;
@@ -121,9 +125,12 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
       per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
     }
-    float scale_val = 1.0f / (*static_cast<float*>(output_s.data_ptr()));
+
     per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()), scale_val, num_elements);
+        static_cast<scalar_t*>(input.data_ptr()),
+        static_cast<FP8_TYPE*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        num_elements);
     return true;
   });
 }
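The substance of the change: the removed line dereferenced `output_s.data_ptr()` on the host, but `output_s` is a CUDA tensor, so that pointer refers to device memory and cannot be read directly from the CPU. The diff instead passes the raw scale pointer into `per_tensor_quant_fp8_kernel` and computes `scale_val = 1.0f / (*scale)` on the device. Below is a minimal, self-contained sketch of the same pattern with a toy kernel; `scale_kernel` and all host-side names are hypothetical illustrations, not part of this PR.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Toy kernel illustrating the pattern from the diff: the scale lives in
// device memory, so the kernel dereferences it on-device instead of the
// host reading through a raw device pointer.
__global__ void scale_kernel(const float* __restrict__ in,
                             float* __restrict__ out,
                             const float* __restrict__ scale,
                             int n) {
  const float scale_val = 1.0f / (*scale);  // on-device dereference
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < n) out[gid] = in[gid] * scale_val;
}

int main() {
  const int n = 8;
  float h_in[n], h_out[n], h_scale = 2.0f;
  for (int i = 0; i < n; ++i) h_in[i] = float(i);

  float *d_in, *d_out, *d_scale;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMalloc(&d_scale, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);

  // No host-side read of d_scale is needed: the kernel consumes it directly.
  scale_kernel<<<1, 32>>>(d_in, d_out, d_scale, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);  // 0 0.5 1 1.5 ...
  printf("\n");

  cudaFree(d_in); cudaFree(d_out); cudaFree(d_scale);
  return 0;
}
```

A side benefit of this structure, visible in the diff itself: because the dynamic path writes the scale in `per_tensor_absmax_kernel` and consumes it in `per_tensor_quant_fp8_kernel` on the same stream, no blocking device-to-host copy is needed between the two launches.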