From e9f8e423189070d3b223457303ddce8ccb9ce1e7 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Mon, 24 Mar 2025 19:50:23 -0700
Subject: [PATCH] Support FP4 gemm (1/2) (#3899)

---
 sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu     |  29 ++
 sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu   | 394 ++++++++++++++++++
 sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu |  39 ++
 .../csrc/gemm/nvfp4_scaled_mm_kernels.cu      | 365 ++++++++++++++++
 sgl-kernel/csrc/torch_extension.cc            |  11 +
 sgl-kernel/include/sgl_kernel_ops.h           |   9 +
 sgl-kernel/python/sgl_kernel/__init__.py      |   2 +
 sgl-kernel/python/sgl_kernel/gemm.py          |  72 +++-
 sgl-kernel/setup.py                           |  14 +-
 sgl-kernel/tests/test_fp4_gemm.py             | 151 +++++++
 sgl-kernel/tests/test_fp4_quantize.py         | 164 ++++++++
 11 files changed, 1245 insertions(+), 5 deletions(-)
 create mode 100644 sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
 create mode 100644 sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
 create mode 100644 sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu
 create mode 100644 sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
 create mode 100644 sgl-kernel/tests/test_fp4_gemm.py
 create mode 100644 sgl-kernel/tests/test_fp4_quantize.py

diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
new file mode 100644
index 000000000..60fda7dce
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
@@ -0,0 +1,29 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <torch/all.h>
+
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+void scaled_fp4_quant_sm100a(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
+#endif
+
+void scaled_fp4_quant(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
+}
diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
new file mode 100644
index 000000000..fa96442df
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu
@@ -0,0 +1,394 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + // PTX instructions used here requires sm100a. +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), + "f"(array[1]), + "f"(array[2]), + "f"(array[3]), + "f"(array[4]), + "f"(array[5]), + "f"(array[6]), + "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + // PTX instructions used here requires sm100a. +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. 
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int64_t mTileStride = numKTiles * 32 * 4 * 4;
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * 4 * 4;
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * 4;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4;
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    // Compute the global offset.
+    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
+                       innerMIdx * innerMStride + innerKIdx * innerKStride;
+
+    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
+  }
+#endif
+  return nullptr;
+}
+
+// Define a 16 bytes packed data type.
+template <class Type>
+struct PackedVec {
+  typename TypeConverter<Type>::Type elts[4];
+};
+
+template <>
+struct PackedVec<__nv_fp8_e4m3> {
+  __nv_fp8x2_e4m3 elts[8];
+};
+
+// Quantizes the provided PackedVec into the uint32_t output
+template <class Type, bool UE8M0_SF = false>
+__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  // Get absolute maximum values among the local 8 values.
+  auto localMax = __habs2(vec.elts[0]);
+
+// Local maximum value.
+#pragma unroll
+  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+  }
+
+  // Get the absolute maximum among all 16 values (two threads).
+  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
+  // Get the final absolute maximum values.
+  float vecMax = float(__hmax(localMax.x, localMax.y));
+
+  // Get the SF (max value of the vector / max value of e2m1).
+  // maximum value of e2m1 = 6.0.
+  // TODO: use half as compute data type.
+  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+  // 8 bits representation of the SF.
+  uint8_t fp8SFVal;
+  // Write the SF to global memory (STG.8).
+  if constexpr (UE8M0_SF) {
+    __nv_fp8_e8m0 tmp;
+    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
+    SFValue = static_cast<float>(tmp);
+    fp8SFVal = tmp.__x;
+  } else {
+    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
+    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+    fp8SFVal = tmp.__x;
+    SFValue = static_cast<float>(tmp);
+  }
+  // Get the output scale.
+  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
+  //                       reciprocal(SFScaleVal)
+  float outputScale =
+      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
+
+  if (SFout) {
+    // Write the SF to global memory (STG.8).
+    *SFout = fp8SFVal;
+  }
+
+  // Convert the input to float.
+  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
+
+#pragma unroll
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    if constexpr (std::is_same_v<Type, half>) {
+      fp2Vals[i] = __half22float2(vec.elts[i]);
+    } else {
+      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
+    }
+    fp2Vals[i].x *= outputScale;
+    fp2Vals[i].y *= outputScale;
+  }
+
+  // Convert to e2m1 values.
+  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+
+  // Write the e2m1 values to global memory.
+  return e2m1Vec;
+#else
+  return 0;
+#endif
+}
+
+// Use UE4M3 by default.
+template <class Type, bool UE8M0_SF = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(512, 4) cvt_fp16_to_fp4(
+#else
+cvt_fp16_to_fp4(
+#endif
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+
+  // Get the global scaling factor, which will be applied to the SF.
+  // Note SFScale is the same as next GEMM's alpha, which is
+  // (448.f / (Alpha_A / 6.f)).
+  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+
+  // Input tensor row/col loops.
+  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
+    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
+      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
+      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+      // Get the output tensor offset.
+      // Same as inOffset because 8 elements are packed into one uint32_t.
+      int64_t outOffset = inOffset;
+      auto& out_pos = out[outOffset];
+
+      auto sf_out =
+          cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);
+
+      out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+    }
+  }
+#endif
+}
+
+template <typename T>
+void invokeFP4Quantization(
+    int m,
+    int n,
+    T const* input,
+    float const* SFScale,
+    int64_t* output,
+    int32_t* SFOuput,
+    bool useUE8M0,
+    int multiProcessorCount,
+    cudaStream_t stream) {
+  // Grid, Block size.
+  // Each thread converts 8 values.
+  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
+  // Get number of blocks per SM (assume we can fully utilize the SM).
+  int const numBlocksPerSM = 2048 / block.x;
+  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // Launch the cvt kernel.
+  if (useUE8M0) {
+    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
+        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
+  } else {
+    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
+        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
+  }
+}
+
+// Instantiate the function.
+template void invokeFP4Quantization(
+    int m,
+    int n,
+    half const* input,
+    float const* SFScale,
+    int64_t* output,
+    int32_t* SFOuput,
+    bool useUE8M0,
+    int multiProcessorCount,
+    cudaStream_t stream);
+
+template void invokeFP4Quantization(
+    int m,
+    int n,
+    __nv_bfloat16 const* input,
+    float const* SFScale,
+    int64_t* output,
+    int32_t* SFOuput,
+    bool useUE8M0,
+    int multiProcessorCount,
+    cudaStream_t stream);
+
+inline int getMultiProcessorCount() {
+  static int multi_processor_count = []() {
+    int device_id = 0;
+    int count = 0;
+
+    // Get the current CUDA device ID
+    CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
+
+    // Get the number of multiprocessors for the current device
+    CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));
+
+    return count;  // Initialize the static variable
+  }();
+
+  return multi_processor_count;  // Return the cached value on subsequent calls
+}
+
+void scaled_fp4_quant_sm100a(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
+  int32_t m = input.size(0);
+  int32_t n = input.size(1);
+
+  TORCH_CHECK(n % 16 == 0, "The N dimension must be a multiple of 16.");
+
+  int multiProcessorCount = getMultiProcessorCount();
+
+  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
+  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
+  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
+
+  // We don't support e8m0 scales at this moment.
+  bool useUE8M0 = false;
+
+  switch (input.scalar_type()) {
+    case torch::kHalf: {
+      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
+      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);
+      break;
+    }
+    case torch::kBFloat16: {
+      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
+      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream);
+      break;
+    }
+    default: {
+      std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
+      throw std::runtime_error("Unsupported input data type for quantize_to_fp4.");
+    }
+  }
+}
diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu
new file mode 100644
index 000000000..8bcd8c52b
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_entry.cu
@@ -0,0 +1,39 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include + +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void cutlass_scaled_fp4_mm_sm100a( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha); +#endif + +void cutlass_scaled_fp4_mm( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel."); +} diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu new file mode 100644 index 000000000..c3da779f6 --- /dev/null +++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu @@ -0,0 +1,365 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +// Kernel Perf config +template +struct KernelTraits; + +template <> +struct KernelTraits { + using MmaTileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; +}; + +template <> +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_4, _4, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; +}; + +template <> +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_4, _4, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; +}; + +template +struct Fp4GemmSm100 { + // A matrix configuration + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + // B matrix configuration + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + // C/D matrix configuration + using ElementD = T; + using ElementC = T; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int 
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  // Kernel functional config
+  using ElementAccumulator = float;
+  using ArchTag = cutlass::arch::Sm100;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+
+  // Kernel Perf config
+  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
+  using ClusterShape = typename KernelTraits<T>::ClusterShape;
+  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      PerSmTileShape_MNK,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementAccumulator,
+      ElementC,
+      LayoutCTag,
+      AlignmentC,
+      ElementD,
+      LayoutDTag,
+      AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      LayoutATag,
+      AlignmentA,
+      ElementB,
+      LayoutBTag,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
+};
+
+template <typename T>
+typename T::Gemm::Arguments args_from_options(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    at::Tensor const& alpha,
+    int64_t M,
+    int64_t N,
+    int64_t K) {
+  using ElementA = typename T::Gemm::ElementA;
+  using ElementB = typename T::Gemm::ElementB;
+  using ElementSFA = cutlass::float_ue4m3_t;
+  using ElementSFB = cutlass::float_ue4m3_t;
+  using ElementD = typename T::Gemm::ElementD;
+  using ElementCompute = float;
+  using StrideA = typename T::StrideA;
+  using StrideB = typename T::StrideB;
+  using StrideD = typename T::StrideD;
+  using Sm100BlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+
+  int m = static_cast<int>(M);
+  int n = static_cast<int>(N);
+  int k = static_cast<int>(K);
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
+
+  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
+  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
+
+  typename T::Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {m, n, k, 1},
+      {// Mainloop arguments
+       static_cast<ElementA const*>(A.data_ptr()),
+       stride_A,
+       static_cast<ElementB const*>(B.data_ptr()),
+       stride_B,
+       static_cast<ElementSFA const*>(A_sf.data_ptr()),
+       layout_SFA,
+       static_cast<ElementSFB const*>(B_sf.data_ptr()),
+       layout_SFB},
+      {// Epilogue arguments
+       {},  // epilogue.thread
+       static_cast<ElementD const*>(D.data_ptr()),
+       stride_D,
+       static_cast<ElementD*>(D.data_ptr()),
+       stride_D}};
+  auto& fusion_args = arguments.epilogue.thread;
+  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
+  return arguments;
+}
+
+template <typename T>
+void runGemm(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    at::Tensor const& alpha,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    cudaStream_t stream) {
+  typename Fp4GemmSm100<T>::Gemm gemm;
+
+  auto arguments = args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
+
+  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
+  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
+
+  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
+}
+#else
+template <typename T>
+void runGemm(
+    at::Tensor& D,
+    at::Tensor const& A,
+    at::Tensor const& B,
+    at::Tensor const& A_sf,
+    at::Tensor const& B_sf,
+    at::Tensor const& alpha,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    cudaStream_t stream) {
+  TORCH_CHECK(
+      false,
+      "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
+      "a CUTLASS 3.8 source directory to enable support.");
+}
+#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
+#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
+#define CHECK_INPUT(x, st, m) \
+  CHECK_TH_CUDA(x, m);        \
+  CHECK_CONTIGUOUS(x, m);     \
+  CHECK_TYPE(x, st, m)
+
+constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
+constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
+
+void cutlass_scaled_fp4_mm_sm100a(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha) {
+  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
+  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
+
+  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
+  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
+
+  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
+
+  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
+  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
+  TORCH_CHECK(
+      A.sizes()[1] == B.sizes()[1],
+      "a and b shapes cannot be multiplied (",
+      A.sizes()[0],
+      "x",
+      A.sizes()[1],
+      " and ",
+      B.sizes()[0],
+      "x",
+      B.sizes()[1],
+      ")");
+
+  auto const m = A.sizes()[0];
+  auto const n = B.sizes()[0];
+  auto const k = A.sizes()[1] * 2;
+
+  constexpr int alignment = 32;
+  TORCH_CHECK(
+      k % alignment == 0,
+      "Expected k to be divisible by ",
+      alignment,
+      ", but got a shape: (",
+      A.sizes()[0],
+      "x",
+      A.sizes()[1],
+      "), k: ",
+      k,
+      ".");
+  TORCH_CHECK(
+      n % alignment == 0,
+      "Expected n to be divisible by ",
+      alignment,
+      ", but got b shape: (",
+      B.sizes()[0],
+      "x",
+      B.sizes()[1],
+      ").");
+
+  auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
+  int rounded_m = round_up(m, 128);
+  int rounded_n = round_up(n, 128);
+  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
+  // integer.
+  int rounded_k = round_up(k / 16, 4);
+
+  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
+  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
+  TORCH_CHECK(
+      A_sf.sizes()[1] == B_sf.sizes()[1],
+      "scale_a and scale_b shapes cannot be multiplied (",
+      A_sf.sizes()[0],
+      "x",
+      A_sf.sizes()[1],
+      " and ",
+      B_sf.sizes()[0],
+      "x",
+      B_sf.sizes()[1],
+      ")");
+  TORCH_CHECK(
+      A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
+      "scale_a must be padded and swizzled to a shape (",
+      rounded_m,
+      "x",
+      rounded_k,
+      "), but got a shape (",
+      A_sf.sizes()[0],
+      "x",
+      A_sf.sizes()[1],
+      ")");
+  TORCH_CHECK(
+      B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
+      "scale_b must be padded and swizzled to a shape (",
+      rounded_n,
+      "x",
+      rounded_k,
+      "), but got a shape (",
+      B_sf.sizes()[0],
+      "x",
+      B_sf.sizes()[1],
+      ")");
+
+  auto out_dtype = D.dtype();
+  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
+
+  if (out_dtype == at::ScalarType::Half) {
+    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (out_dtype == at::ScalarType::BFloat16) {
+    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (out_dtype == at::ScalarType::Float) {
+    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+  }
+}
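
Note (reviewer sketch, not part of the patch): the checks above encode a shape
contract that is easy to get wrong from the Python side. A minimal mirror of
that contract, assuming two e2m1 values per byte in `a` and `b`; the helper
name `expected_nvfp4_mm_shapes` is ours, for illustration only.

    def expected_nvfp4_mm_shapes(m: int, n: int, k: int) -> dict:
        """Shapes cutlass_scaled_fp4_mm_sm100a expects for an (m, k) x (n, k)^T GEMM."""
        assert k % 32 == 0 and n % 32 == 0, "kernel alignment is 32"
        round_up = lambda x, y: (x + y - 1) // y * y
        rounded_m = round_up(m, 128)      # scales are padded to 128-row M tiles
        rounded_n = round_up(n, 128)
        rounded_k = round_up(k // 16, 4)  # one fp8 scale per 16 elements, padded to 4
        return {
            "a": (m, k // 2),                   # packed uint8, two fp4 values per byte
            "b": (n, k // 2),
            "scale_a": (rounded_m, rounded_k),  # swizzled float8_e4m3fn
            "scale_b": (rounded_n, rounded_k),
        }
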
diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc
index fe5e09734..80c5c73d1 100644
--- a/sgl-kernel/csrc/torch_extension.cc
+++ b/sgl-kernel/csrc/torch_extension.cc
@@ -114,6 +114,17 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
   m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
 
+  m.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      " Tensor block_scale_a, Tensor block_scale_b,"
+      " Tensor alpha) -> ()");
+  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+
+  m.def(
+      "scaled_fp4_quant(Tensor! output, Tensor! input,"
+      " Tensor! output_scale, Tensor! input_scale) -> ()");
+  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+
   /*
    * From csrc/moe
    */
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 98bd7bdac..36921a29f 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -113,6 +113,13 @@ void apply_rope_pos_ids_cos_sin_cache(
  * From csrc/gemm
  */
 torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
+void cutlass_scaled_fp4_mm(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha);
 torch::Tensor int8_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,
@@ -133,6 +140,8 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Tensor& scales_a,
     const torch::Tensor& scales_b,
     const torch::Dtype& out_dtype);
+void scaled_fp4_quant(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
 void sgl_per_token_group_quant_fp8(
     at::Tensor input,
     at::Tensor output_q,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index fbc4ac675..1a668fee4 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -26,9 +26,11 @@ from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
     cublas_grouped_gemm,
+    cutlass_scaled_fp4_mm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
+    scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
     sgl_per_token_group_quant_int8,
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index bab9e3c8c..f63360722 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
@@ -145,3 +145,73 @@ def sgl_per_token_quant_fp8(
     output_s: torch.Tensor,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+
+
+def cutlass_scaled_fp4_mm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    block_scale_a: torch.Tensor,
+    block_scale_b: torch.Tensor,
+    alpha: torch.Tensor,
+    out_dtype: torch.dtype,
+) -> torch.Tensor:
+    assert a.ndim == 2 and b.ndim == 2
+    m, n = a.shape[0], b.shape[0]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm(
+        out, a, b, block_scale_a, block_scale_b, alpha
+    )
+    return out
+
+
+def scaled_fp4_quant(
+    input: torch.Tensor, input_global_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP4 and return quantized tensor and scale.
+
+    This function quantizes the last dimension of the given tensor `input`. For
+    every 16 consecutive elements, a single dynamically computed scaling factor
+    is shared. This scaling factor is quantized using the `input_global_scale`
+    and is stored in a swizzled layout (see
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
+
+    Args:
+        input: The input tensor to be quantized to FP4
+        input_global_scale: A scalar scaling factor for the entire tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
+        two values are packed into a uint8 and float8_e4m3 scaling factors
+        in a swizzled layout.
+ """ + assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." + other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." + assert input.dtype in ( + torch.float16, + torch.bfloat16, + ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." + + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Then, the scaling + # factors in float8_e4m3fn are packed into an int32 for every 4 values. + rounded_m = ((m + 128 - 1) // 128) * 128 + scale_n = n // block_size + rounded_n = ((scale_n + 4 - 1) // 4) * 4 + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + + torch.ops.sgl_kernels.scaled_fp4_quant( + output, input, output_scale, input_global_scale + ) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index a7b612220..7b21ba593 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -153,6 +153,10 @@ sources = [ "csrc/gemm/fp8_gemm_kernel.cu", "csrc/gemm/fp8_blockwise_gemm_kernel.cu", "csrc/gemm/int8_gemm_kernel.cu", + "csrc/gemm/nvfp4_quant_entry.cu", + "csrc/gemm/nvfp4_quant_kernels.cu", + "csrc/gemm/nvfp4_scaled_mm_entry.cu", + "csrc/gemm/nvfp4_scaled_mm_kernels.cu", "csrc/gemm/per_token_group_quant_8bit.cu", "csrc/gemm/per_token_quant_fp8.cu", "csrc/gemm/per_tensor_quant_fp8.cu", @@ -169,6 +173,7 @@ sources = [ enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1" +enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1" enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1" enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1" cuda_version = _get_cuda_version() @@ -180,6 +185,7 @@ if torch.cuda.is_available(): if cuda_version >= (12, 8) and sm_version >= 100: nvcc_flags.append("-gencode=arch=compute_100,code=sm_100") nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a") + nvcc_flags.append("-DENABLE_NVFP4=1") else: nvcc_flags.append("-use_fast_math") if sm_version >= 90: @@ -188,12 +194,12 @@ if torch.cuda.is_available(): nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") else: # compilation environment without GPU - if enable_sm90a: - nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") if enable_sm100a: nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a") - else: - nvcc_flags.append("-use_fast_math") + if enable_sm90a: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + if enable_fp4: + nvcc_flags.append("-DENABLE_NVFP4=1") if enable_fp8: nvcc_flags.extend(nvcc_flags_fp8) if enable_bf16: diff --git a/sgl-kernel/tests/test_fp4_gemm.py b/sgl-kernel/tests/test_fp4_gemm.py new file mode 100644 index 000000000..5c092bd13 --- /dev/null +++ b/sgl-kernel/tests/test_fp4_gemm.py @@ -0,0 +1,151 @@ +import pytest +import torch +from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant + +if torch.cuda.get_device_capability() < (10, 0): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), 
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index a7b612220..7b21ba593 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -153,6 +153,10 @@ sources = [
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
+    "csrc/gemm/nvfp4_quant_entry.cu",
+    "csrc/gemm/nvfp4_quant_kernels.cu",
+    "csrc/gemm/nvfp4_scaled_mm_entry.cu",
+    "csrc/gemm/nvfp4_scaled_mm_kernels.cu",
     "csrc/gemm/per_token_group_quant_8bit.cu",
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",
@@ -169,6 +173,7 @@ sources = [
 
 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
 enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
+enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1"
 enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
 enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
 cuda_version = _get_cuda_version()
@@ -180,6 +185,7 @@ if torch.cuda.is_available():
     if cuda_version >= (12, 8) and sm_version >= 100:
         nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
         nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
+        nvcc_flags.append("-DENABLE_NVFP4=1")
     else:
         nvcc_flags.append("-use_fast_math")
     if sm_version >= 90:
@@ -188,12 +194,12 @@ if torch.cuda.is_available():
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 else:
     # compilation environment without GPU
-    if enable_sm90a:
-        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
     if enable_sm100a:
         nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
-    else:
-        nvcc_flags.append("-use_fast_math")
+    if enable_sm90a:
+        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if enable_fp4:
+        nvcc_flags.append("-DENABLE_NVFP4=1")
     if enable_fp8:
         nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:
diff --git a/sgl-kernel/tests/test_fp4_gemm.py b/sgl-kernel/tests/test_fp4_gemm.py
new file mode 100644
index 000000000..5c092bd13
--- /dev/null
+++ b/sgl-kernel/tests/test_fp4_gemm.py
@@ -0,0 +1,151 @@
+import pytest
+import torch
+from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
+
+if torch.cuda.get_device_capability() < (10, 0):
+    pytest.skip(
+        reason="NVFP4 requires compute capability of 10 or above.",
+        allow_module_level=True,
+    )
+
+DTYPES = [torch.float16, torch.bfloat16]
+# m, n, packed_k (two fp4 values per byte, so k = packed_k * 2)
+SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
+PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
+SHAPES.extend(PAD_SHAPES)
+
+FLOAT4_E2M1_MAX = 6.0
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+kE2M1ToFloatArray = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+]
+
+
+def e2m1_to_fp32(int4_value):
+    signBit = int4_value & 0x8
+    int4_absValue = int4_value & 0x7
+    float_result = kE2M1ToFloatArray[int4_absValue]
+    if signBit:
+        float_result = -float_result
+    return float_result
+
+
+def break_fp4_bytes(a, dtype):
+    assert a.dtype == torch.uint8
+    m, n = a.shape
+    a = a.flatten()
+    # Get upper 4 bits
+    highHalfByte = (a & 0xF0) >> 4
+    # Get lower 4 bits
+    lowHalfByte = a & 0x0F
+    fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
+    fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
+    # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
+    out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
+    return out
+
+
+def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
+    sf_m, sf_k = a_sf_swizzled.shape
+    m_tiles = (m + 128 - 1) // 128
+    f = block_size * 4
+    k_tiles = (k + f - 1) // f
+    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
+    return out[0:m, 0:k]
+
+
+def dequantize_to_dtype(
+    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
+):
+    """Dequantize the fp4 tensor back to high precision."""
+    # Two fp4 values are packed into one uint8.
+    assert tensor_fp4.dtype == torch.uint8
+    m, packed_k = tensor_fp4.shape
+    k = packed_k * 2
+    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
+    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
+    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
+    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
+
+    # scale the tensor
+    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
+    return out
+
+
+def get_ref_results(
+    a_fp4,
+    b_fp4,
+    a_sf,
+    b_sf,
+    a_global_scale,
+    b_global_scale,
+    m,
+    n,
+    dtype,
+    block_size,
+    device,
+):
+    _, m_k = a_fp4.shape
+    _, n_k = b_fp4.shape
+    assert m_k == n_k
+    a_in_dtype = dequantize_to_dtype(
+        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
+    )
+    b_in_dtype = dequantize_to_dtype(
+        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
+    )
+    return torch.matmul(a_in_dtype, b_in_dtype.t())
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@torch.inference_mode()
+def test_nvfp4_gemm(
+    dtype: torch.dtype,
+    shape: tuple[int, int, int],
+) -> None:
+    m, n, packed_k = shape
+    k = packed_k * 2
+    block_size = 16
+    a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
+    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
+
+    a_global_scale = (
+        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
+    ).to(torch.float32)
+    b_global_scale = (
+        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
+    ).to(torch.float32)
+    alpha = 1.0 / (a_global_scale * b_global_scale)
+    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
+    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
+
+    expected_out = get_ref_results(
+        a_fp4,
+        b_fp4,
+        a_scale_interleaved,
+        b_scale_interleaved,
+        a_global_scale,
+        b_global_scale,
+        m,
+        n,
+        dtype,
+        block_size,
+        "cuda",
+    )
+    out = cutlass_scaled_fp4_mm(
+        a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
+    )
+
+    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
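
Note (layout sketch, not part of the patch): both test files recover the swizzled
scale layout with a reshape/permute; the same mapping can also be written per
element, mirroring the index arithmetic of cvt_quant_to_fp4_get_sf_out_offset in
nvfp4_quant_kernels.cu. `swizzled_sf_offset` is our name, for illustration only.

    def swizzled_sf_offset(row: int, kblock: int, rounded_k: int) -> int:
        """Linear offset of the scale for (row, 16-element block kblock) inside
        the flat (rounded_m, rounded_k) scale tensor.
        Layout: [mTile, kTile, 32 (outer M), 4 (inner M), 4 (inner K)]."""
        num_k_tiles = rounded_k // 4
        m_tile, m_in_tile = divmod(row, 128)  # 128 = 32 * 4 rows per M tile
        k_tile, inner_k = divmod(kblock, 4)
        outer_m = m_in_tile % 32              # the [32, 4] M block is column-major
        inner_m = m_in_tile // 32
        return (((m_tile * num_k_tiles + k_tile) * 32 + outer_m) * 4 + inner_m) * 4 + inner_k
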
diff --git a/sgl-kernel/tests/test_fp4_quantize.py b/sgl-kernel/tests/test_fp4_quantize.py
new file mode 100644
index 000000000..6b2489314
--- /dev/null
+++ b/sgl-kernel/tests/test_fp4_quantize.py
@@ -0,0 +1,164 @@
+import pytest
+import torch
+from sgl_kernel import scaled_fp4_quant
+
+if torch.cuda.get_device_capability() < (10, 0):
+    pytest.skip(
+        reason="NVFP4 requires compute capability of 10 or above.",
+        allow_module_level=True,
+    )
+
+DTYPES = [torch.float16, torch.bfloat16]
+SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
+PAD_SHAPES = [
+    (90, 64),
+    (150, 64),
+    (128, 48),
+    (128, 80),
+    (150, 80),
+    (90, 48),
+    (90, 128),
+    (150, 128),
+    (150, 48),
+    (90, 80),
+]
+
+FLOAT4_E2M1_MAX = 6.0
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+# E2M1 to float
+# 0111 -> 6
+# 0110 -> 4
+# 0101 -> 3
+# 0100 -> 2
+# 0011 -> 1.5
+# 0010 -> 1
+# 0001 -> 0.5
+# 0000 -> 0
+E2M1_TO_FLOAT32 = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+    0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+]
+BLOCK_SIZE = 16
+
+
+def cast_from_fp4(x, m, n):
+    # The fp4 values are packed in uint8 as [v_1st | v_2nd]
+    v_2nd = x & 0xF
+    v_1st = (x >> 4) & 0xF
+    c = torch.stack((v_2nd, v_1st), dim=-1)
+    out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()])
+    out = out.reshape(m, n).to(torch.float32)
+    return out
+
+
+def cast_to_fp4(x):
+    sign = torch.sign(x)
+    x = torch.abs(x)
+    x[(x >= 0.0) & (x <= 0.25)] = 0.0
+    x[(x > 0.25) & (x < 0.75)] = 0.5
+    x[(x >= 0.75) & (x <= 1.25)] = 1.0
+    x[(x > 1.25) & (x < 1.75)] = 1.5
+    x[(x >= 1.75) & (x <= 2.5)] = 2.0
+    x[(x > 2.5) & (x < 3.5)] = 3.0
+    x[(x >= 3.5) & (x <= 5.0)] = 4.0
+    x[x > 5.0] = 6.0
+    return x * sign
+
+
+def get_reciprocal(x):
+    if isinstance(x, torch.Tensor):
+        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
+    elif isinstance(x, (float, int)):
+        return 0.0 if x == 0 else 1.0 / x
+    else:
+        raise TypeError("Input must be a float, int, or a torch.Tensor.")
+
+
+def ref_nvfp4_quant(x, global_scale):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
+    vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
+    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
+    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
+    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
+
+    scaled_x = x.to(torch.float32) * output_scale
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    return cast_to_fp4(clipped_x), scale.squeeze(-1)
+
+
+def recover_swizzled_scales(scale, m, n):
+    rounded_m = ((m + 128 - 1) // 128) * 128
+    scale_n = n // BLOCK_SIZE
+    rounded_n = ((scale_n + 4 - 1) // 4) * 4
+    # Recover the swizzled scaling factor to linear layout
+    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
+    return result[:m, :scale_n]
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@torch.inference_mode()
+def test_quantize_to_fp4(
+    dtype: torch.dtype,
+    shape: tuple[int, int],
+) -> None:
+    torch.manual_seed(42)
+    torch.set_default_device("cuda:0")
+
+    m, n = shape
+
+    x = torch.randn((m, n), dtype=dtype)
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
+
+    out, out_scale = scaled_fp4_quant(x, global_scale)
+    scale_ans = recover_swizzled_scales(out_scale, m, n)
+    out_ans = cast_from_fp4(out, m, n)
+
+    torch.testing.assert_close(out_ans, out_ref)
+    torch.testing.assert_close(scale_ans, scale_ref)
+
+
+@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
+@torch.inference_mode()
+def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
+    torch.manual_seed(42)
+    dtype = torch.float16
+    torch.set_default_device("cuda:0")
+
+    m, n = pad_shape
+
+    x = torch.randn((m, n), dtype=dtype)
+
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
+
+    out, out_scale = scaled_fp4_quant(x, global_scale)
+
+    scale_ans = recover_swizzled_scales(out_scale, m, n)
+    out_ans = cast_from_fp4(out, m, n)
+
+    torch.testing.assert_close(out_ans, out_ref)
+    torch.testing.assert_close(scale_ans, scale_ref)
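
Note (worked example, not part of the patch): the per-block recipe that
ref_nvfp4_quant above and the CUDA kernel both implement, spelled out once with
concrete numbers; standalone and illustrative only.

    import torch

    block = torch.tensor([3.1, -0.2, 1.6, 0.9] * 4)        # one 16-element block, amax = 3.1
    global_scale = torch.tensor(448.0 * 6.0 / 3.1)         # FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

    sf = global_scale * (block.abs().max() / 6.0)          # per-block scale in fp32
    sf_fp8 = sf.to(torch.float8_e4m3fn).to(torch.float32)  # rounded to e4m3, as stored
    output_scale = 1.0 / (sf_fp8 / global_scale)           # multiplier applied before e2m1 rounding

    q = torch.clamp(block * output_scale, -6.0, 6.0)       # now inside the e2m1 range [-6, 6]
    dequant = q * (sf_fp8 / global_scale)                  # ~= block, up to e2m1 rounding error
    print(q)
    print(dequant)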