diff --git a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu index 38d943653..1228b21c5 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu @@ -1,169 +1,11 @@ #include #include -#include #include +#include #include -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { - // PTX instructions used here requires sm100a. -#if CUDA_VERSION >= 12080 -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), - "f"(array[1]), - "f"(array[2]), - "f"(array[3]), - "f"(array[4]), - "f"(array[5]), - "f"(array[6]), - "f"(array[7])); - return val; -#else - return 0; -#endif -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { - // PTX instructions used here requires sm100a. -#if CUDA_VERSION >= 12080 -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), - "f"(array[0].y), - "f"(array[1].x), - "f"(array[1].y), - "f"(array[2].x), - "f"(array[2].y), - "f"(array[3].x), - "f"(array[3].y)); - return val; -#else - return 0; -#endif -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + - innerMIdx * innerMStride + innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; +#include "nvfp4_quant.cuh" +#include "utils.h" // Quantizes the provided PackedVec into the uint32_t output template @@ -720,6 +562,9 @@ void scaled_fp4_experts_quant_sm100a( torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts) { + auto sm_version = getSMVersion(); + TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a"); + CHECK_INPUT(output, "output must be a CUDA tensor"); CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); CHECK_INPUT(input, "input must be a CUDA tensor"); @@ -801,6 +646,9 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( torch::Tensor const& input_global_scale, torch::Tensor const& mask, bool use_silu_and_mul) { + auto sm_version = getSMVersion(); + TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a"); + CHECK_INPUT(output, "output must be a CUDA tensor"); CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); CHECK_INPUT(input, "input must be a CUDA tensor"); diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant.cuh b/sgl-kernel/csrc/gemm/nvfp4_quant.cuh new file mode 100644 index 000000000..b2aa5f006 --- /dev/null +++ b/sgl-kernel/csrc/gemm/nvfp4_quant.cuh @@ -0,0 +1,176 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + // PTX instructions used here requires sm100a/sm103a. +#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), + "f"(array[1]), + "f"(array[2]), + "f"(array[3]), + "f"(array[4]), + "f"(array[5]), + "f"(array[6]), + "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + // PTX instructions used here requires sm100a/sm103a. +#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu index 5024d20af..d307f5fb7 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu @@ -15,176 +15,13 @@ limitations under the License. #include #include -#include -#include #include #include #include +#include "nvfp4_quant.cuh" #include "utils.h" -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { - // PTX instructions used here requires sm100a. -#if CUDA_VERSION >= 12080 -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), - "f"(array[1]), - "f"(array[2]), - "f"(array[3]), - "f"(array[4]), - "f"(array[5]), - "f"(array[6]), - "f"(array[7])); - return val; -#else - return 0; -#endif -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { - // PTX instructions used here requires sm100a. -#if CUDA_VERSION >= 12080 -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), - "f"(array[0].y), - "f"(array[1].x), - "f"(array[1].y), - "f"(array[2].x), - "f"(array[2].y), - "f"(array[3].x), - "f"(array[3].y)); - return val; -#else - return 0; -#endif -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + - innerMIdx * innerMStride + innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - // Quantizes the provided PackedVec into the uint32_t output template __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { @@ -364,6 +201,9 @@ inline int getMultiProcessorCount() { void scaled_fp4_quant_sm100a( torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { + auto sm_version = getSMVersion(); + TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a"); + int32_t m = input.size(0); int32_t n = input.size(1);