diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 9046fc676..199982c9b 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -1,4 +1,4 @@ -"""Cutlass MoE kernel.""" +"""CUTLASS based Fused MoE kernels.""" import functools import json @@ -14,8 +14,10 @@ _is_cuda = is_cuda() if _is_cuda: import sgl_kernel from sgl_kernel import ( + cutlass_fp4_group_mm, fp8_blockwise_scaled_grouped_mm, prepare_moe_input, + scaled_fp4_experts_quant, silu_and_mul, ) @@ -205,3 +207,178 @@ def cutlass_fused_experts( return ( c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) ).sum(dim=1) + + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = 448.0 + + +def cutlass_moe_fp4( + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + ab_strides_13: torch.Tensor, + ab_strides_2: torch.Tensor, + c_strides_13: torch.Tensor, + c_strides_2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, +): + """ + MoE implementation for FP4 Inputs + + # Gemm 1 + a: Input tensor: [m, k] (half/bfloat16) + a1_gscale: Activation scale per expert: [e] (float32) + w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] + w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) + (Note: `n` is the up projection output dim, `k` is the input dim in + full precision) + w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) + (Block size = 16 for NVFP4) + + # Gemm 2 + a2_gscale: Activation scale per expert: [e] + w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] + w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) + w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 + + Strides for activations, weights and output in logical number of elements. + The activations & output stride is the number of elements to the next row. + The weights stride is the number of elements to the next row per expert. + For example, if the weight is [e, n, k], then the b_stride is a tensor of + shape [e] with each element being k. Similarly for activations, if the + shape is [m, k], then the a_stride has shape [e] with each value k. + Similarly for output, if the output is [m, n], then the c_stride is a + tensor of shape [e] with each element being n. + + Note: cutlass_fp4_group_mm is designed to accept the strides of + activations and weights to be the same, so it is passed in as a single + tensor. + ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides] + ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides] + c_strides_13: [e] dtype: int64 [Gemm 1: Output strides] + c_strides_2: [e] dtype: int64 [Gemm 2: Output strides] + + topk_weights: [m, topk] dtype: float32 + topk_ids: [m, topk] dtype: int32 + + m, n, k: Unquantized weight shapes, dtype: int + e: number of experts for the current rank, dtype: int + assumes that topk < k < n to satisfy the up/down projection expectations.
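+
+    Illustrative stride construction (this mirrors the accompanying
+    test_fp4_moe.py and assumes contiguous, row-major packed weights):
+        ab_strides_13 = torch.full((e,), k, dtype=torch.int64, device=device)
+        c_strides_13 = torch.full((e,), 2 * n, dtype=torch.int64, device=device)
+        ab_strides_2 = torch.full((e,), n, dtype=torch.int64, device=device)
+        c_strides_2 = torch.full((e,), k, dtype=torch.int64, device=device)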
+ """ + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" + assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" + assert ( + w1_fp4.ndim == 3 + and w2_fp4.ndim == 3 + and w1_blockscale.ndim == 3 + and w2_blockscale.ndim == 3 + ), "All Weights must be of rank 3 for cutlass_moe_fp4" + m_a, k_a = a.shape + e_w1, nx2_w1, half_k_w1 = w1_fp4.shape + e_w2, k_w2, half_n_w2 = w2_fp4.shape + + assert e_w1 == e_w2 and e_w1 == e, ( + "Number of experts must match", + " between weights.", + ) + assert ( + k_a // 2 == half_k_w1 and k == k_w2 + ), "Hidden size mismatch between a, w1 and w2" + assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`" + assert m == m_a, "input shape mismatch" + assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert ( + topk_weights.shape[0] == m and topk_ids.shape[0] == m + ), "topk must be provided for each row of a" + + out_dtype = a.dtype + num_topk = topk_ids.shape[1] + + expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,2n,k)) + problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,n,k)) + problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + # problem shapes should have [m, n, k] + # Note that problem sizes are based on logical number of elements. + blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device) + prepare_moe_input( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + e, + n, + k, + blockscale_offsets, + ) + + rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant( + a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map + ) + + c1 = cutlass_fp4_group_mm( + rep_a_fp4, + w1_fp4, + rep_a_blockscale, + w1_blockscale, + w1_alphas, + ab_strides_13, + c_strides_13, + problem_sizes1, + expert_offsets[:-1], + blockscale_offsets[:-1], + out_dtype, + device, + ) + del rep_a_fp4, rep_a_blockscale + # hidden size dimension is split to one halfpytho sized tensor. 
+ intermediate = torch.empty( + (m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype + ) + + silu_and_mul(c1, intermediate) + + int_fp4, int_blockscale = scaled_fp4_experts_quant( + intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk + ) + c2 = cutlass_fp4_group_mm( + int_fp4, + w2_fp4, + int_blockscale, + w2_blockscale, + w2_alphas, + ab_strides_2, + c_strides_2, + problem_sizes2, + expert_offsets[:-1], + blockscale_offsets[:-1], + out_dtype, + device, + ) + del int_fp4, int_blockscale + out = ( + c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half() + ).sum(dim=1) + return out.to(dtype=out_dtype) diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py new file mode 100644 index 000000000..df3d1f7c1 --- /dev/null +++ b/python/sglang/test/test_fp4_moe.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch +from sgl_kernel import scaled_fp4_quant + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 +from sglang.srt.layers.moe.topk import select_experts + +if torch.cuda.get_device_capability() < (10, 0): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + +FLOAT8_E4M3_MAX = 448.0 +FLOAT4_E2M1_MAX = 6.0 + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_nvfp4_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. 
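+    # Recovery is decode(e2m1) * (block_scale / global_scale): the swizzled
+    # scale tensor is first mapped back to a linear [m, k // block_size]
+    # layout, then each block of 16 values is multiplied by its scale.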
+ assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1024), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +# Reference implementation of torch_moe +def torch_moe(a, w1, w2, score, topk, expert_map): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_cutlass_fp4_moe_no_graph( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype +): + + torch.manual_seed(7) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + quant_blocksize = 16 + round_up = lambda x, y: (x + y - 1) // y * y + sf_w1_2n = round_up(2 * n, 128) + sf_w1_k = round_up(k // quant_blocksize, 4) + w1_blockscale = torch.empty( + (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) + + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + sf_w2_k = round_up(k, 128) + sf_w2_n = round_up(n // quant_blocksize, 4) + w2_blockscale = torch.empty( + (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn + ) + + w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) + w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_amax = torch.abs(w1).max().to(torch.float32) + w2_amax = 
torch.abs(w2).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant( + w1[expert], w1_gs[expert] + ) + + w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant( + w2[expert], w2_gs[expert] + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + use_grouped_topk=False, + renormalize=False, + ) + + a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + # strides for the cutlass moe_fp4 kernel + ab_strides_13 = torch.full( + (e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device + ) + c_strides_13 = torch.full( + (e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device + ) + ab_strides_2 = torch.full( + (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device + ) + c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device) + cutlass_output = cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_q, + w1_blockscale=w1_blockscale, + w1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + w2_fp4=w2_q, + w2_blockscale=w2_blockscale, + w2_alphas=(1 / w2_gs), + ab_strides_13=ab_strides_13, + ab_strides_2=ab_strides_2, + c_strides_13=c_strides_13, + c_strides_2=c_strides_2, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=e, + device=a.device, + ) + + # Reference check: + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) + a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=quant_blocksize, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=quant_blocksize, + ) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + + torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 75c929662..09f8f529f 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -210,6 +210,7 @@ set(SOURCES "csrc/gemm/fp8_blockwise_gemm_kernel.cu" "csrc/gemm/fp8_gemm_kernel.cu" "csrc/gemm/int8_gemm_kernel.cu" + "csrc/gemm/nvfp4_expert_quant.cu" "csrc/gemm/nvfp4_quant_entry.cu" "csrc/gemm/nvfp4_quant_kernels.cu" "csrc/gemm/nvfp4_scaled_mm_entry.cu" @@ -222,6 +223,7 @@ set(SOURCES "csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_topk_softmax_kernels.cu" + "csrc/moe/nvfp4_blockwise_moe.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu" "csrc/moe/prepare_moe_input.cu" "csrc/moe/ep_moe_reorder_kernel.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 2a28eb103..0936a88c8 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ 
b/sgl-kernel/csrc/common_extension.cc @@ -132,6 +132,20 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " Tensor! output_scale, Tensor! input_scale) -> ()"); m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Compute NVFP4 experts quantization. + m.def( + "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + + m.def( + "cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b," + "Tensor a_blockscale, Tensor b_blockscale, Tensor alphas," + "Tensor ab_strides, Tensor c_strides, Tensor problem_sizes," + " Tensor expert_offsets, Tensor sf_offsets) -> ()"); + m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + /* * From csrc/moe */ @@ -161,9 +175,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "expert_offsets, Tensor workspace) -> ()"); m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm); m.def( - "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor " - "input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()"); + "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1," + " Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> " + "()"); m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input); + + m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()"); + m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + /* * From csrc/speculative */ diff --git a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu new file mode 100644 index 000000000..32cffb388 --- /dev/null +++ b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu @@ -0,0 +1,431 @@ +#include +#include +#include +#include +#include + +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + // PTX instructions used here requires sm100a. +#if CUDA_VERSION >= 12080 +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), + "f"(array[1]), + "f"(array[2]), + "f"(array[3]), + "f"(array[4]), + "f"(array[5]), + "f"(array[6]), + "f"(array[7])); + return val; +#else + return 0; +#endif +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
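+// E2M1 stores 1 sign bit, 2 exponent bits and 1 mantissa bit per value, so the
+// representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}. Each
+// cvt.rn.satfinite.e2m1x2.f32 instruction packs two fp32 inputs into one byte,
+// and four such bytes form the returned uint32_t.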
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + // PTX instructions used here requires sm100a. +#if CUDA_VERSION >= 12080 +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +#else + return 0; +#endif +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. 
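+  // Each scale factor covers CVT_FP4_SF_VEC_SIZE (16) elements while a thread
+  // holds only CVT_FP4_ELTS_PER_THREAD (8), so the xor-shuffle above merges
+  // the maxima of the two threads that share one SF vector.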
+ float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, + int32_t numCols, + Type const* in, + float const* SFScale, + uint32_t* out, + uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, + int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + // Find index within the experts. + int rowIdx_in_expert = 0; + int expert_idx = 0; + for (int i = 0; i < n_experts; i++) { + if (rowIdx >= input_offset_by_experts[i] && rowIdx < input_offset_by_experts[i + 1]) { + rowIdx_in_expert = rowIdx - input_offset_by_experts[i]; + expert_idx = i; + break; + } + } + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 
1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void quant_impl( + void* output, + void* output_scale, + void* input, + void* input_global_scale, + void* input_offset_by_experts, + void* output_scale_offset_by_experts, + int m_topk, + int k, + int n_experts, + cudaStream_t stream) { + // TODO: this multiProcessorCount should be cached. + int device; + cudaGetDevice(&device); + int multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); + + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(k / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM)); + + cvt_fp16_to_fp4<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts); +} + +/*Quantization entry for fp4 experts quantization*/ +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + +// constexpr auto FP8 = at::ScalarType::Float8_e4m3fn; +constexpr auto HALF = at::ScalarType::Half; +constexpr auto BF16 = at::ScalarType::BFloat16; +constexpr auto FLOAT = at::ScalarType::Float; +constexpr auto INT = at::ScalarType::Int; +constexpr auto UINT8 = at::ScalarType::Byte; + +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + CHECK_INPUT(output, "output must be a CUDA tensor"); + CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); + CHECK_INPUT(input, "input must be a CUDA tensor"); + CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); + CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four 
fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + auto m_topk = input.size(0); + auto k = input.size(1); + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + + auto in_dtype = input.dtype(); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); + if (in_dtype == at::ScalarType::Half) { + quant_impl( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), + m_topk, + k, + n_experts, + stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + quant_impl<__nv_bfloat16>( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), + m_topk, + k, + n_experts, + stream); + } else { + TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); + } +} diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu index 60fda7dce..8b6a0a275 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu @@ -18,6 +18,15 @@ limitations under the License. 
#if defined ENABLE_NVFP4 && ENABLE_NVFP4 void scaled_fp4_quant_sm100a( torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf); + +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); + #endif void scaled_fp4_quant( @@ -27,3 +36,17 @@ void scaled_fp4_quant( #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); } + +void scaled_fp4_experts_quant( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); +} diff --git a/sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu b/sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu new file mode 100644 index 000000000..c68ca552d --- /dev/null +++ b/sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu @@ -0,0 +1,471 @@ +#include +#include +#include +#include +#include + +#include + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +using namespace cute; + +template < + typename ElementAB, + typename ElementC, + typename ElementSF, + typename ElementAccumulator, + typename LayoutSFA, + typename LayoutSFB, + typename ScaleConfig> +__global__ void __get_group_gemm_starts( + ElementAB** a_offsets, + ElementAB** b_offsets, + ElementC** out_offsets, + ElementSF** a_scales_offsets, + ElementSF** b_scales_offsets, + ElementAccumulator** alpha_offsets, + LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, + ElementAB* a_base_as_int, + ElementAB* b_base_as_int, + ElementC* out_base_as_int, + ElementSF* a_scales_base_as_int, + ElementSF* b_scales_base_as_int, + ElementAccumulator* alphas_base_as_int, + const int32_t* expert_offsets, + const int32_t* sf_offsets, + const int32_t* problem_sizes_as_shapes, + const int K, + const int N) { + int64_t expert_id = threadIdx.x; + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + // Originally int32_t but upcasting to int64_t to avoid overflow + // during offset calculations + int64_t expert_offset = 
static_cast(expert_offsets[expert_id]); + int64_t sf_offset = static_cast(sf_offsets[expert_id]); + // size for block in block scale. + int64_t group_size = 16; + int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]); + int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]); + int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]); + assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes"); + + int64_t half_k = static_cast(k / 2); + int64_t group_k = static_cast(k / group_size); + // Shape of A as uint8/byte = [M, K // 2] + // Shape of B as uint8/byte = [E, N, K // 2] + a_offsets[expert_id] = a_base_as_int + expert_offset * half_k; + + b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k; + // Shape of C = [M, N] + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + // Shape of a_scale = [sum(sf_sizes), K // group_size] + a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k; + + assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment"); + + // Shape of B scale = [E, N, K // group_size] + b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k; + assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment"); + // Shape of alpha = [E] + alpha_offsets[expert_id] = alphas_base_as_int + expert_id; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA( + cute::make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB( + cute::make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); +} + +#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE( \ + ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + __get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(a_starts.data_ptr()), \ + static_cast(b_starts.data_ptr()), \ + static_cast(out_starts.data_ptr()), \ + static_cast(a_scales_starts.data_ptr()), \ + static_cast(b_scales_starts.data_ptr()), \ + static_cast(alpha_starts.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast(alphas.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(sf_offsets.data_ptr()), \ + static_cast(problem_sizes.data_ptr()), \ + K, \ + N); \ + } + +template +void run_get_group_gemm_starts( + const torch::Tensor& a_starts, + const torch::Tensor& b_starts, + const torch::Tensor& out_starts, + const torch::Tensor& a_scales_starts, + const torch::Tensor& b_scales_starts, + const torch::Tensor& alpha_starts, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + /*these are used for their base addresses*/ + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& out_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& alphas, + torch::Tensor const& expert_offsets, + torch::Tensor const& sf_offsets, + torch::Tensor const& problem_sizes, + int M, + int N, + int K) { + int num_experts = 
(int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + TORCH_CHECK(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape"); + TORCH_CHECK( + K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); + if (false) { + } + //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, + // ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, + torch::kBFloat16, + cutlass::bfloat16_t, + LayoutSFA, + LayoutSFB, + ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& a_blockscale, + const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, + const torch::Tensor& ab_strides, + const torch::Tensor& c_strides, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& sf_offsets, + int M, + int N, + int K) { + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag + using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + }; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + EpilogueOperatorClass, + typename MMA1SMConfig::MmaTileShape, + ClusterShape, + Shape<_128, _64>, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutC*, + AlignmentC, + ElementD, + LayoutC*, + AlignmentD, + typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + MainloopOperatorClass, + ElementA, + LayoutA*, + AlignmentA, + ElementB, + LayoutB*, + AlignmentB, + ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + 
typename MMA1SMConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + + run_get_group_gemm_starts( + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + a, + b, + output, + a_blockscale, + b_blockscales, + alphas, + expert_offsets, + sf_offsets, + problem_sizes, + M, + N, + K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams< + typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(ab_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(ab_strides.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + 
cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +void cutlass_fp4_group_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& a_blockscale, + const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, + const torch::Tensor& ab_strides, + const torch::Tensor& c_strides, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& sf_offsets) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + + constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; + constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + // Input validation + CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); + CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); + CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); + CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); + + TORCH_CHECK( + a_blockscale.dim() == 2, + "expected a_blockscale to be of shape [num_experts, rounded_m," + " k // group_size], observed rank: ", + a_blockscale.dim()) + TORCH_CHECK( + b_blockscales.dim() == 3, + "expected b_blockscale to be of shape: " + " [num_experts, n, k // group_size], observed rank: ", + b_blockscales.dim()) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have the shape (num_experts, 3)"); + TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32."); + + int M = static_cast(a.size(0)); + int N = static_cast(b.size(1)); + int E = static_cast(b.size(0)); + int K = static_cast(2 * b.size(2)); + + if (output.scalar_type() == torch::kBFloat16) { + run_fp4_blockwise_scaled_group_mm( + output, + a, + b, + a_blockscale, + b_blockscales, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + sf_offsets, + M, + N, + K); + } else { + run_fp4_blockwise_scaled_group_mm( + output, + a, + b, + a_blockscale, + b_blockscales, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + sf_offsets, + M, + N, + K); + } +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel, sgl-kernel 
must " + "be compiled with ENABLE_NVFP4 for SM100+ and CUDA " + "12.8 or above."); +#endif +} diff --git a/sgl-kernel/csrc/moe/prepare_moe_input.cu b/sgl-kernel/csrc/moe/prepare_moe_input.cu index 5f3010301..06237b56e 100644 --- a/sgl-kernel/csrc/moe/prepare_moe_input.cu +++ b/sgl-kernel/csrc/moe/prepare_moe_input.cu @@ -4,6 +4,8 @@ #include +#include "cutlass/array.h" + constexpr uint64_t THREADS_PER_EXPERT = 512; __global__ void compute_problem_sizes( @@ -11,9 +13,9 @@ __global__ void compute_problem_sizes( int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, - const int topk_length, - const int n, - const int k) { + const int64_t topk_length, + const int64_t n, + const int64_t k) { int expert_id = blockIdx.x; int occurrences = 0; @@ -26,11 +28,11 @@ __global__ void compute_problem_sizes( if (threadIdx.x == 0) { int final_occurrences = atomic_buffer[expert_id]; problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes1[expert_id * 3 + 1] = static_cast(2 * n); + problem_sizes1[expert_id * 3 + 2] = static_cast(k); problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; + problem_sizes2[expert_id * 3 + 1] = static_cast(k); + problem_sizes2[expert_id * 3 + 2] = static_cast(n); } } @@ -38,7 +40,7 @@ __global__ void compute_expert_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, int32_t* atomic_buffer, - const int num_experts) { + const int64_t num_experts) { int32_t tot_offset = 0; expert_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { @@ -48,13 +50,34 @@ __global__ void compute_expert_offsets( } } +__global__ void compute_expert_blockscale_offsets( + const int32_t* __restrict__ problem_sizes1, + int32_t* expert_offsets, + int32_t* blockscale_offsets, + int32_t* atomic_buffer, + const int64_t num_experts) { + int32_t tot_offset = 0; + int32_t tot_rounded_offset = 0; + expert_offsets[0] = 0; + blockscale_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + int num_tokens = problem_sizes1[i * 3]; + int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128; + tot_offset += num_tokens; + tot_rounded_offset += rounded_num_tokens; + expert_offsets[i + 1] = tot_offset; + blockscale_offsets[i + 1] = tot_rounded_offset; + } +} + __global__ void compute_arg_sorts( - const int* __restrict__ topk_ids, + const int32_t* __restrict__ topk_ids, int32_t* input_permutation, int32_t* output_permutation, int32_t* atomic_buffer, - const int topk_length, - const int topk) { + const int64_t topk_length, + const int64_t topk) { int expert_id = blockIdx.x; for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { @@ -69,6 +92,7 @@ __global__ void compute_arg_sorts( void get_moe_prepare_input_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + const std::optional& blockscale_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, @@ -80,8 +104,10 @@ void get_moe_prepare_input_caller( auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); - int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - compute_problem_sizes<<>>( + uint32_t num_threads = static_cast(min(THREADS_PER_EXPERT, topk_ids.numel())); + uint32_t num_blocks = static_cast(num_experts); + + 
compute_problem_sizes<<>>( static_cast(topk_ids.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), @@ -89,12 +115,21 @@ void get_moe_prepare_input_caller( topk_ids.numel(), n, k); - compute_expert_offsets<<<1, 1, 0, stream>>>( - static_cast(problem_sizes1.data_ptr()), - static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - num_experts); - compute_arg_sorts<<>>( + if (blockscale_offsets.has_value()) { + compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(blockscale_offsets.value().data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); + } else { + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); + } + compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), @@ -106,6 +141,7 @@ void get_moe_prepare_input_caller( void prepare_moe_input( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + const std::optional& blockscale_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, @@ -117,6 +153,7 @@ void prepare_moe_input( get_moe_prepare_input_caller( topk_ids, expert_offsets, + blockscale_offsets, problem_sizes1, problem_sizes2, input_permutation, @@ -126,3 +163,92 @@ void prepare_moe_input( k); return; } + +template +__global__ void shuffleRowsKernel( + const T* input, + const int32_t* dst2src_map, + T* output, + int64_t num_src_rows, + int64_t num_dst_rows, + int64_t num_cols) { + int64_t dest_row_idx = blockIdx.x; + int64_t const source_row_idx = dst2src_map[dest_row_idx]; + + if (blockIdx.x < num_dst_rows) { + // Load 128-bits per thread + constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8; + using DataElem = cutlass::Array; + + // Duplicate and permute rows + auto const* source_row_ptr = reinterpret_cast(input + source_row_idx * num_cols); + auto* dest_row_ptr = reinterpret_cast(output + dest_row_idx * num_cols); + + auto const start_offset = threadIdx.x; + auto const stride = blockDim.x; + auto const num_elems_in_col = num_cols / ELEM_PER_THREAD; + + for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { + dest_row_ptr[elem_index] = source_row_ptr[elem_index]; + } + } +} + +#define DECLARE_SHUFFLE_ROWS(T) \ + __global__ void shuffleRowsKernel( \ + const T* input, \ + const int32_t* dst2src_map, \ + T* output, \ + int64_t num_src_rows, \ + int64_t num_dest_rows, \ + int64_t num_cols); + +DECLARE_SHUFFLE_ROWS(float); +DECLARE_SHUFFLE_ROWS(half); +DECLARE_SHUFFLE_ROWS(__nv_bfloat16); +DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3); +DECLARE_SHUFFLE_ROWS(uint8_t); + +#define SHUFFLE_ROWS(T) \ + shuffleRowsKernel<<>>( \ + reinterpret_cast(input), \ + static_cast(dst2src_map.data_ptr()), \ + reinterpret_cast(output), \ + num_src_rows, \ + num_dst_rows, \ + num_cols) + +#define DTYPE_DISPATCH_CASE(T, CUDA_T) \ + case T: \ + SHUFFLE_ROWS(CUDA_T); \ + break; + +void shuffle_rows_caller( + const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) { + TORCH_CHECK( + input_tensor.scalar_type() == output_tensor.scalar_type(), + "Input and output tensors must have the same data type"); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + uint32_t blocks = 
static_cast(output_tensor.size(0)); + uint32_t threads = 256; + int64_t num_dst_rows = output_tensor.size(0); + int64_t num_src_rows = input_tensor.size(0); + int64_t num_cols = input_tensor.size(1); + const void* input = input_tensor.data_ptr(); + void* output = output_tensor.data_ptr(); + switch (input_tensor.scalar_type()) { + DTYPE_DISPATCH_CASE(torch::kFloat16, half); + DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16); + DTYPE_DISPATCH_CASE(torch::kFloat32, float); + DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3); + DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t); + default: + TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!"); + } + return; +} + +void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) { + shuffle_rows_caller(input_tensor, dst2src_map, output_tensor); + return; +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 9c3432f46..f66e754fd 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -232,6 +232,7 @@ void fp8_blockwise_scaled_grouped_mm( void prepare_moe_input( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + const std::optional& blockscale_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, @@ -251,6 +252,29 @@ void ep_moe_pre_reorder( int64_t topk, bool use_per_token_if_dynamic); +void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor); + +void cutlass_fp4_group_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& a_blockscale, + const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, + const torch::Tensor& ab_strides, + const torch::Tensor& c_strides, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& sf_offsets); + +void scaled_fp4_experts_quant( + torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor const& input, + torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); + /* * From csrc/speculative */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 002d8d394..a7e371456 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -38,14 +38,17 @@ from sgl_kernel.gemm import ( int8_scaled_mm, qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm, + scaled_fp4_experts_quant, scaled_fp4_quant, sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, + shuffle_rows, ) from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda from sgl_kernel.moe import ( + cutlass_fp4_group_mm, ep_moe_pre_reorder, fp8_blockwise_scaled_grouped_mm, moe_align_block_size, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 113542436..48a21ee8b 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -241,3 +241,80 @@ def qserve_w4a8_per_group_gemm( in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats ) return out_feats + + +def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape): + output_tensor = torch.empty( + output_tensor_shape, + device=input_tensor.device, + dtype=input_tensor.dtype, + ) + 
torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor) + return output_tensor + + +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, + expert_map: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + packed MoE Inputs. + Args: + input: The input tensor to be quantized to FP4 + expert_map: The expert map tensor + input_global_scale: A scalar scaling factor for the entire tensor. + expert_offsets: The expert offsets tensor + blockscale_offsets: The blockscale offsets tensor + Outputs: + output: The quantized tensor in FP4 + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert ( + input_tensor.ndim == 2 + ), f"input.ndim needs to be == 2, but got {input_tensor.ndim}." + if expert_map is not None: + (m, k) = input_tensor.shape + output_tensor_shape = (m * topk, k) + input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape) + m_numtopk, k = input_tensor.shape + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE Expert Quantization. This is used to prevent the kernel + # from running out of memory. This value can also be increased to support + # larger models. + import os + + MAX_TOKENS_PER_EXPERT = os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536) + assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" + f"{MAX_TOKENS_PER_EXPERT})" + f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" + f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value." + ) + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty( + m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 + ) + output_scales = torch.empty( + MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device, + ) + torch.ops.sgl_kernel.scaled_fp4_experts_quant.default( + output, + output_scales, + input_tensor, + input_global_scale, + expert_offsets, + blockscale_offsets, + ) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 27808494d..b9e4adcc8 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch @@ -138,10 +140,12 @@ def prepare_moe_input( num_experts, n, k, + blockscale_offsets: Optional[torch.Tensor] = None, ): torch.ops.sgl_kernel.prepare_moe_input.default( topk_ids, expert_offsets, + blockscale_offsets, problem_sizes1, problem_sizes2, input_permutation, @@ -150,3 +154,54 @@ def prepare_moe_input( n, k, ) + + +def cutlass_fp4_group_mm( + a_fp4, + b_fp4, + a_blockscale, + b_blockscale, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + blockscale_offsets, + out_dtype, + device, +): + """ + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + the gemms for each combination based on the specified problem sizes. + + This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. + - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized + input and expert weights. 
+ - a_/b_scales: The blockscales in FP8-E4M3 precision + - ab_strides/c_strides: Strides for the a/b tensors between rows. + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is + sf_offset[E+1] - sf_offset[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + """ + m_topk = a_fp4.shape[0] + n = b_fp4.shape[1] + c_shape = (m_topk, n) + c = torch.empty(c_shape, device=device, dtype=out_dtype) + torch.ops.sgl_kernel.cutlass_fp4_group_mm.default( + c, + a_fp4, + b_fp4, + a_blockscale, + b_blockscale, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + blockscale_offsets, + ) + return c.to(dtype=out_dtype)
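The offset bookkeeping above is easy to get wrong, so here is a minimal pure-Python sketch of what compute_expert_blockscale_offsets in prepare_moe_input.cu produces (the helper name reference_offsets is invented for illustration and is not part of the kernels): expert_offsets is a plain prefix sum of per-expert token counts, while blockscale_offsets rounds each count up to a multiple of 128 rows, matching the swizzled block-scale layout expected by the FP4 group GEMM.

import torch

def reference_offsets(problem_sizes1: torch.Tensor):
    # problem_sizes1[i, 0] holds expert i's token count, as written by
    # compute_problem_sizes.
    num_experts = problem_sizes1.shape[0]
    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
    blockscale_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
    for i in range(num_experts):
        num_tokens = int(problem_sizes1[i, 0])
        rounded = (num_tokens + 127) // 128 * 128  # round up to 128 rows
        expert_offsets[i + 1] = expert_offsets[i] + num_tokens
        blockscale_offsets[i + 1] = blockscale_offsets[i] + rounded
    return expert_offsets, blockscale_offsets

# Example: per-expert token counts [3, 130, 0] give
#   expert_offsets     = [0, 3, 133, 133]
#   blockscale_offsets = [0, 128, 384, 384]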