diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index ac11ff2a7..c204dc151 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -157,6 +157,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "Tensor output_scale_offset_by_experts) -> ()"); m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + m.def( + "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts, Tensor mask) -> ()"); + m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant); + m.def( "cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b," "Tensor a_blockscale, Tensor b_blockscale, Tensor alphas," diff --git a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu index af52196f6..3f996f668 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu @@ -239,6 +239,33 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, #endif } +__device__ __forceinline__ float silu(const float& val) { + return val / (1.0f + __expf(-val)); +} + +template +inline __device__ void silu_and_mul(PackedVec& x_vec, const PackedVec& y_vec) { + float2 x[CVT_FP4_ELTS_PER_THREAD / 2]; + float2 y[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + x[i] = __half22float2(x_vec.elts[i]); + y[i] = __half22float2(y_vec.elts[i]); + x[i].x = silu(x[i].x) * y[i].x; + x[i].y = silu(x[i].y) * y[i].y; + x_vec.elts[i] = __float22half2_rn(x[i]); + } else { + x[i] = __bfloat1622float2(x_vec.elts[i]); + y[i] = __bfloat1622float2(y_vec.elts[i]); + x[i].x = silu(x[i].x) * y[i].x; + x[i].y = silu(x[i].y) * y[i].y; + x_vec.elts[i] = __float22bfloat162_rn(x[i]); + } + } +} + // Use UE4M3 by default. template __global__ void @@ -255,6 +282,7 @@ cvt_fp16_to_fp4( uint32_t* SFout, uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts, + int32_t* mask, int n_experts, bool low_latency) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -265,6 +293,11 @@ cvt_fp16_to_fp4( // Input tensor row/col loops. int tid = blockIdx.x * blockDim.x + threadIdx.x; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // TODO(kaixih@nvidia): For now, we assume mask is used together with + // silu_and_mal. Maybe we want a more general behavior of mask later. In the + // silu case, the input last dim doubles. + bool use_mask = mask != nullptr; + int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow; // Each global thread processes one element for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) { @@ -272,13 +305,6 @@ cvt_fp16_to_fp4( int rowIdx = globalIdx / colsPerRow; int colIdx = globalIdx % colsPerRow; - int64_t inOffset = rowIdx * colsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - // Find index within the experts using different strategies based on expert // count int rowIdx_in_expert = 0; @@ -321,6 +347,23 @@ cvt_fp16_to_fp4( } } + // Eerly exit when using masks. + if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { + continue; + } + + int64_t inOffset = rowIdx * actualColsPerRow + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + if (use_mask) { + PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is // (448.f / (Alpha_A / 6.f)). @@ -356,6 +399,7 @@ cvt_fp16_to_fp4( uint32_t* SFout, uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts, + int32_t* mask, int n_experts) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using PackedVec = PackedVec; @@ -383,6 +427,8 @@ cvt_fp16_to_fp4( int tid = blockIdx.x * blockDim.x + threadIdx.x; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + bool use_mask = mask != nullptr; + int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow; // Each global thread processes one element for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) { @@ -390,11 +436,6 @@ cvt_fp16_to_fp4( int rowIdx = globalIdx / colsPerRow; int colIdx = globalIdx % colsPerRow; - int64_t inOffset = rowIdx * colsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - // Find expert using binary search for better performance with large m_topk int rowIdx_in_expert = 0; int expert_idx = 0; @@ -419,6 +460,21 @@ cvt_fp16_to_fp4( } } + if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { + continue; + } + + int64_t inOffset = rowIdx * actualColsPerRow + colIdx; + + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + if (use_mask) { + PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); + } + + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; int factor = CVT_FP4_SF_VEC_SIZE * 4; @@ -442,6 +498,7 @@ void quant_impl( void* input_global_scale, void* input_offset_by_experts, void* output_scale_offset_by_experts, + void* mask, int m_topk, int k, int n_experts, @@ -478,6 +535,7 @@ void quant_impl( reinterpret_cast(output_scale), reinterpret_cast(input_offset_by_experts), reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), n_experts); } else { cvt_fp16_to_fp4<<>>( @@ -489,6 +547,7 @@ void quant_impl( reinterpret_cast(output_scale), reinterpret_cast(input_offset_by_experts), reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), n_experts); } } else { @@ -502,6 +561,7 @@ void quant_impl( reinterpret_cast(output_scale), reinterpret_cast(input_offset_by_experts), reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), n_experts, /* bool low_latency */ true); } else { @@ -514,6 +574,7 @@ void quant_impl( reinterpret_cast(output_scale), reinterpret_cast(input_offset_by_experts), reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), n_experts, /* bool low_latency */ true); } @@ -590,6 +651,7 @@ void scaled_fp4_experts_quant_sm100a( input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(), + nullptr, // mask m_topk, k, n_experts, @@ -602,6 +664,92 @@ void scaled_fp4_experts_quant_sm100a( input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(), + nullptr, // mask + m_topk, + k, + n_experts, + stream); + } else { + TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); + } +} + +void silu_and_mul_scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts, + torch::Tensor const& mask) { + CHECK_INPUT(output, "output must be a CUDA tensor"); + CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); + CHECK_INPUT(input, "input must be a CUDA tensor"); + CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); + CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(mask, "mask must be a CUDA tensor"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(mask.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + auto m_topk = input.size(0); + auto k_by_2 = input.size(1); + TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2"); + auto k = k_by_2 / 2; + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(mask.size(0) == n_experts); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + + auto in_dtype = input.dtype(); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + if (in_dtype == at::ScalarType::Half) { + quant_impl( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), + mask.data_ptr(), + m_topk, + k, + n_experts, + stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + quant_impl<__nv_bfloat16>( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), + mask.data_ptr(), m_topk, k, n_experts, diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu index 8b6a0a275..335fd512a 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu @@ -27,6 +27,15 @@ void scaled_fp4_experts_quant_sm100a( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); +void silu_and_mul_scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts, + torch::Tensor const& mask); + #endif void scaled_fp4_quant( @@ -50,3 +59,18 @@ void scaled_fp4_experts_quant( #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); } + +void silu_and_mul_scaled_fp4_experts_quant( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts, + torch::Tensor const& mask) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return silu_and_mul_scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 33d883d2c..5765a0b7e 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -389,6 +389,14 @@ void scaled_fp4_experts_quant( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); +void silu_and_mul_scaled_fp4_experts_quant( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts, + torch::Tensor const& mask); /* * From csrc/moe/cutlass_moe/w4a8 */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 6480a097d..05a62efaa 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -52,12 +52,14 @@ from sgl_kernel.gemm import ( qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm, scaled_fp4_experts_quant, + scaled_fp4_grouped_quant, scaled_fp4_quant, sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, shuffle_rows, + silu_and_mul_scaled_fp4_grouped_quant, ) from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda from sgl_kernel.kvcacheio import ( diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index dafc739a1..bd85ee949 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -295,6 +295,142 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape): return output_tensor +def scaled_fp4_grouped_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, +): + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer). + Args: + input: The input tensor to be quantized to FP4, with shape (l, m, k) + l is number of groups, m is number of tokens per group, k is number of features. + input_global_scale: A scalar scaling factor for the entire tensor, with + shape (l,). + Outputs: + output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical + layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into + an uint8. + output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l) + but the physical layout is (l, rm, rk, 32, 4, 4). + Note: + For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128. + `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are + required by the NVIDIA Blackwell MMA operations. + """ + device = input_tensor.device + l, m, k = input_tensor.shape + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." + + scale_k = k // sf_vec_size + padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k_int32 = padded_k // 4 + padded_m = (m + (128 - 1)) // 128 * 128 + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device) + output_offsets = torch.arange( + 0, + (l + 1) * padded_m, + step=padded_m, + dtype=torch.int, + device=device, + ) + + torch.ops.sgl_kernel.scaled_fp4_experts_quant.default( + output.view(l * m, k // 2), + output_scales.view(l * padded_m, padded_k_int32), + input_tensor.view(l * m, k), + input_global_scale, + input_offsets, + output_offsets, + ) + # The physical layout of the output is (l, m, k // 2), but we want to return a + # logical layout (m, k // 2, l) required by the flashinfer masked group gemm. + output = output.permute(1, 2, 0) + # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a + # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic + # layout is (32, 4, rm, 4, rk, l). + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + +def silu_and_mul_scaled_fp4_grouped_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, +): + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer). + Args: + input: The input tensor to be quantized to FP4, with shape (l, m, k * 2) + l is number of groups, m is number of tokens per group, k is number of features. + input_global_scale: A scalar scaling factor for the entire tensor, with + shape (l,). + mask: The mask tensor, with shape (l,) + Outputs: + output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical + layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into + an uint8. + output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l) + but the physical layout is (l, rm, rk, 32, 4, 4). + Note: + For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128. + `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are + required by the NVIDIA Blackwell MMA operations. + """ + device = input_tensor.device + l, m, k_by_2 = input_tensor.shape + k = k_by_2 // 2 + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." + + scale_k = k // sf_vec_size + padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k_int32 = padded_k // 4 + padded_m = (m + (128 - 1)) // 128 * 128 + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device) + output_offsets = torch.arange( + 0, + (l + 1) * padded_m, + step=padded_m, + dtype=torch.int, + device=device, + ) + + torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default( + output.view(l * m, k // 2), + output_scales.view(l * padded_m, padded_k_int32), + input_tensor.view(l * m, k_by_2), + input_global_scale, + input_offsets, + output_offsets, + mask, + ) + # The physical layout of the output is (l, m, k // 2), but we want to return a + # logical layout (m, k // 2, l) required by the flashinfer masked group gemm. + output = output.permute(1, 2, 0) + # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a + # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic + # layout is (32, 4, rm, 4, rk, l). + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + def scaled_fp4_experts_quant( input_tensor: torch.Tensor, input_global_scale: torch.Tensor, diff --git a/sgl-kernel/tests/test_fp4_quantize.py b/sgl-kernel/tests/test_fp4_quantize.py index dcf09e053..6f68330cd 100644 --- a/sgl-kernel/tests/test_fp4_quantize.py +++ b/sgl-kernel/tests/test_fp4_quantize.py @@ -1,6 +1,11 @@ import pytest import torch -from sgl_kernel import scaled_fp4_quant +from sgl_kernel import ( + scaled_fp4_grouped_quant, + scaled_fp4_quant, + silu_and_mul, + silu_and_mul_scaled_fp4_grouped_quant, +) skip_condition = torch.cuda.get_device_capability() < (10, 0) @@ -166,5 +171,83 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: torch.testing.assert_close(scale_ans, scale_ref) +@pytest.mark.skipif( + skip_condition, reason="Nvfp4 Requires compute capability of 10 or above." +) +def test_quantize_to_fp4_grouped(): + torch.manual_seed(42) + torch.set_default_device("cuda:0") + + l, m, k = 2, 512, 2048 + x = torch.randn((l, m, k), dtype=torch.bfloat16) + tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32) + x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + output, output_scales = scaled_fp4_grouped_quant( + x, + x_sf_global, + ) + # output in logical (m, k, l), but its physical layout is (l, m, k). + # So permute first to (l, m, k). + output = output.permute(2, 0, 1) + # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4). + # So permute first to (l, rm, rk, 32, 4, 4). + padded_m = ((m + 128 - 1) // 128) * 128 + output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1) + for i in range(l): + a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i]) + torch.testing.assert_close(a_fp4, output[i]) + torch.testing.assert_close( + a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float) + ) + + +@pytest.mark.skipif( + skip_condition, reason="Nvfp4 Requires compute capability of 10 or above." +) +@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)]) +def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None: + torch.manual_seed(42) + torch.set_default_device("cuda:0") + + l, m, k = shape + x = torch.randn((l, m, k * 2), dtype=torch.bfloat16) + max_m = 8 + assert max_m <= m + mask = torch.randint(1, max_m, (l,), dtype=torch.int32) + + ref_y = silu_and_mul(x) + tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32) + y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + ref_output, ref_output_scales = scaled_fp4_grouped_quant( + ref_y, + y_sf_global, + ) + output, output_scales = silu_and_mul_scaled_fp4_grouped_quant( + x, + y_sf_global, + mask, + ) + + # output in logical (m, k, l), but its physical layout is (l, m, k). + # So permute first to (l, m, k). + output = output.permute(2, 0, 1) + ref_output = ref_output.permute(2, 0, 1) + + # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4). + # So permute first to (l, rm, rk, 32, 4, 4). + padded_m = ((m + 128 - 1) // 128) * 128 + output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1) + ref_output_scales = ref_output_scales.permute(5, 2, 4, 0, 1, 3).view( + l, padded_m, -1 + ) + + for i in range(l): + torch.testing.assert_close(ref_output[i, : mask[i]], output[i, : mask[i]]) + # We need to recover the swizzled scales to linear layout before applying mask slice. + scale_ref = recover_swizzled_scales(ref_output_scales[i], m, k) + scale_ans = recover_swizzled_scales(output_scales[i], m, k) + torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]]) + + if __name__ == "__main__": pytest.main([__file__])