diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py index c56df30f5..0d90b51b3 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py @@ -1,13 +1,12 @@ import itertools -import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Tuple import torch import triton import triton.language as tl from sgl_kernel import sgl_per_token_group_quant_fp8 -from sglang.srt.utils import get_device_core_count, get_device_name, is_hip +from sglang.srt.utils import is_hip is_hip_ = is_hip() fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn diff --git a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py index 19055d2c5..ed0bfc78b 100644 --- a/sgl-kernel/benchmark/bench_per_token_quant_fp8.py +++ b/sgl-kernel/benchmark/bench_per_token_quant_fp8.py @@ -40,9 +40,6 @@ def calculate_diff(batch_size: int, seq_len: int): scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item() output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item() - print(f"Scale difference: {scale_diff}") - print(f"Output difference: {output_diff}") - if torch.allclose( vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5 ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5): diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index d9cabd783..d9290fe01 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -7,38 +7,6 @@ #include "utils.h" -#define WARP_SIZE 32 - -#ifndef USE_ROCM -#include -using FP8_TYPE = c10::Float8_e4m3fn; -C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); -#else -#include - -#include "amd/quant_utils.cuh" -using FP8_TYPE = c10::Float8_e4m3fnuz; -// Using the default max value from pytorch (240.0) will cause accuracy -// issue when running dynamic quantization. Here use 224.0f for rocm. -constexpr auto FP8_E4M3_MAX = 224.0f; -#endif - -__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { - float old; - old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) - : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); - return old; -} - -__device__ __forceinline__ float warpReduceMax(float max_value) { - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); - return max_value; -} - template __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) { diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index be272e065..5528ad8c5 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -1,5 +1,4 @@ #include -#include #include #include @@ -7,31 +6,6 @@ #include "utils.h" -#define WARP_SIZE 32 - -#ifndef USE_ROCM -#include -using FP8_TYPE = c10::Float8_e4m3fn; -C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); -#else -#include - -#include "amd/quant_utils.cuh" -using FP8_TYPE = c10::Float8_e4m3fnuz; -// Using the default max value from pytorch (240.0) will cause accuracy -// issue when running dynamic quantization. Here use 224.0f for rocm. -constexpr auto FP8_E4M3_MAX = 224.0f; -#endif - -__device__ __forceinline__ float warpReduceMax(float max_value) { - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2)); - max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); - return max_value; -} - template __global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q, float* __restrict__ output_s, const int64_t hidden_dim, diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index a342dee10..3f574c954 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -95,3 +95,33 @@ inline int getSMVersion() { AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) #define CEILDIV(x, y) (((x) + (y)-1) / (y)) + +#define WARP_SIZE 32 + +#ifndef USE_ROCM +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else +#include + +#include "amd/quant_utils.cuh" +using FP8_TYPE = c10::Float8_e4m3fnuz; +constexpr auto FP8_E4M3_MAX = 224.0f; +#endif + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +} + +__device__ __forceinline__ float warpReduceMax(float max_value) { + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); + return max_value; +}