From 0df6765c83e2ea1263295812e0979aa6801377c0 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Thu, 5 Jun 2025 13:13:14 -0700 Subject: [PATCH] [CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887) Signed-off-by: Pavani Majety --- python/sglang/srt/layers/moe/cutlass_moe.py | 93 ++++------ .../srt/layers/moe/cutlass_moe_params.py | 169 ++++++++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 4 +- python/sglang/test/test_cutlass_moe.py | 6 +- python/sglang/test/test_fp4_moe.py | 19 +- sgl-kernel/python/sgl_kernel/moe.py | 18 +- 6 files changed, 230 insertions(+), 79 deletions(-) create mode 100644 python/sglang/srt/layers/moe/cutlass_moe_params.py diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 199982c9b..5b90c4fb4 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import torch +from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams from sglang.srt.utils import is_cuda _is_cuda = is_cuda() @@ -18,11 +19,12 @@ if _is_cuda: fp8_blockwise_scaled_grouped_mm, prepare_moe_input, scaled_fp4_experts_quant, + shuffle_rows, silu_and_mul, ) -def cutlass_fused_experts( +def cutlass_fused_experts_fp8( a: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, @@ -223,17 +225,10 @@ def cutlass_moe_fp4( w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor, - ab_strides_13: torch.Tensor, - ab_strides_2: torch.Tensor, - c_strides_13: torch.Tensor, - c_strides_2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - m: int, - n: int, - k: int, - e: int, - device: torch.device, + params: CutlassMoEParams, + apply_router_weight_on_input: bool = False, ): """ MoE implementation for FP4 Inputs @@ -291,77 +286,70 @@ def cutlass_moe_fp4( e_w1, nx2_w1, half_k_w1 = w1_fp4.shape 
e_w2, k_w2, half_n_w2 = w2_fp4.shape - assert e_w1 == e_w2 and e_w1 == e, ( + assert e_w1 == e_w2 and e_w1 == params.num_experts, ( "Number of experts must match", " between weights.", ) assert ( - k_a // 2 == half_k_w1 and k == k_w2 + k_a // 2 == half_k_w1 and params.hidden_size == k_w2 ), "Hidden size mismatch between a, w1 and w2" - assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`" - assert m == m_a, "input shape mismatch" + assert ( + nx2_w1 == params.intermediate_size_per_partition * 2 + and half_n_w2 == params.intermediate_size_per_partition // 2 + ), ("mismatch in " "expected `n`") assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert ( - topk_weights.shape[0] == m and topk_ids.shape[0] == m - ), "topk must be provided for each row of a" out_dtype = a.dtype num_topk = topk_ids.shape[1] - - expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) - # Problem size: (num_experts, (m,2n,k)) - problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) - # Problem size: (num_experts, (m,n,k)) - problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) - + device = a.device a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - - # problem shapes should have [m, n, k] - # Note that problem sizes are based on logical number of elements. 
- blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device) prepare_moe_input( topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, + params.expert_offsets, + params.problem_sizes1, + params.problem_sizes2, a_map, c_map, - e, - n, - k, - blockscale_offsets, + params.num_experts, + params.intermediate_size_per_partition, + params.hidden_size, + params.blockscale_offsets, ) rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant( - a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map + a, + a1_gscale, + params.expert_offsets, + params.blockscale_offsets, + num_topk, + expert_map=a_map, ) - c1 = cutlass_fp4_group_mm( rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale, w1_alphas, - ab_strides_13, - c_strides_13, - problem_sizes1, - expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device, + params.to_gemm1_args(), ) del rep_a_fp4, rep_a_blockscale + # hidden size dimension is split to one half sized tensor. intermediate = torch.empty( - (m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype + (m_a * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype ) - silu_and_mul(c1, intermediate) int_fp4, int_blockscale = scaled_fp4_experts_quant( - intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk + intermediate, + a2_gscale, + params.expert_offsets, + params.blockscale_offsets, + num_topk, ) c2 = cutlass_fp4_group_mm( int_fp4, @@ -369,16 +357,13 @@ def cutlass_moe_fp4( int_blockscale, w2_blockscale, w2_alphas, - ab_strides_2, - c_strides_2, - problem_sizes2, - expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device, + params.to_gemm2_args(), ) del int_fp4, int_blockscale - out = ( - c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half() - ).sum(dim=1) - return out.to(dtype=out_dtype) + c2 = shuffle_rows(c2, c_map, (m_a * num_topk, params.hidden_size)) + c2 = c2.view(m_a, num_topk, params.hidden_size) + if not apply_router_weight_on_input: + 
c2 = c2 * topk_weights.view(m_a, num_topk, 1).to(out_dtype) + return c2.sum(dim=1).to(out_dtype) diff --git a/python/sglang/srt/layers/moe/cutlass_moe_params.py b/python/sglang/srt/layers/moe/cutlass_moe_params.py new file mode 100644 index 000000000..f3de60e04 --- /dev/null +++ b/python/sglang/srt/layers/moe/cutlass_moe_params.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import Optional + +import torch + + +class CutlassMoEType(Enum): + """ + Enum for the different types of cutlass moe operations + that are currently supported in SGLang. + """ + + BlockscaledFP8 = auto() + BlockscaledFP4 = auto() + + +@dataclass +class CutlassMoEParams: + """ + Parameters for the cutlass moe operation. + """ + + # Type as defined above + cutlass_moe_type: CutlassMoEType + + # Strides for activations, weights and output in logical number of elements. + # The activations & output stride is the number of elements to the next row. + # The weights stride is the number of elements to the next row per expert. + # For example, if the weight is [e, n, k], then the b_stride is a tensor of + # shape [e] with each element being k. Similarly for activations, if the + # shape is [m, k], then the a_stride has shape [e] with each value k. + # Similarly for output, if the output is [m, n], then the c_stride is a + # tensor of shape [e] with each element being n. + + # Note: cutlass_fp4_group_mm is designed to accept the strides of + # activations and weights to be the same, so it is passed in as a single + # tensor.
+ # ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides] + # ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides] + # c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides] + # c_strides_2: [e] dtype: int64 [Gemm 2: Output Strides] + ab_strides_13: torch.Tensor + ab_strides_2: torch.Tensor + c_strides_13: torch.Tensor + c_strides_2: torch.Tensor + + # m: Total number of tokens + # n: intermediate size per partition + # k: hidden size per expert + # e: Number of experts + # device: Device to run computation on and store tensors + m: int + intermediate_size_per_partition: int + hidden_size: int + num_experts: int + device: torch.device + + # Pointers container for calculating offsets of the input activations for each expert + # a_ptrs: [e] dtype: int64 + a_ptrs: torch.Tensor + + # Pointers container for calculating offsets of the input weights for each expert + # b_ptrs: [e] dtype: int64 + b_ptrs: torch.Tensor + + # Pointers container for calculating offsets of the output activations for each expert + # out_ptrs: [e] dtype: int64 + out_ptrs: torch.Tensor + # Pointers container for calculating offsets of the input scales for each expert + # a_scales_ptrs: [e] dtype: int64 + # b_scales_ptrs: [e] dtype: int64 + a_scales_ptrs: torch.Tensor + b_scales_ptrs: torch.Tensor + + # Offsets that mark at which token index each expert begins its computation + # The number of tokens computed with expert E is expert_offsets[E + 1] - expert_offsets[E] + # expert_offsets: [e+1] dtype: int32 + expert_offsets: torch.Tensor + + # Problem size: (num_experts, (m,2n,k)) for first GEMM + # problem_sizes1: [e, 3] dtype: int32 + # Problem size: (num_experts, (m,n,k)) for second GEMM + # problem_sizes2: [e, 3] dtype: int32 + problem_sizes1: torch.Tensor + problem_sizes2: torch.Tensor + # Similar to expert_offsets, but for blockscales for FP4 blockscaled Group GEMM + blockscale_offsets: Optional[torch.Tensor] = None + + def __init__( + self, + cutlass_moe_type: 
CutlassMoEType, + device: torch.device, + num_experts: int, + intermediate_size_per_partition: int, + hidden_size: int, + ): + self.cutlass_moe_type = cutlass_moe_type + self.device = device + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size_per_partition + self.hidden_size = hidden_size + self.n = self.intermediate_size_per_partition + self.k = self.hidden_size + self.e = self.num_experts + self.ab_strides_13 = torch.full( + (self.e,), self.k, dtype=torch.int64, device=self.device + ) + self.ab_strides_2 = torch.full( + (self.e,), self.n, dtype=torch.int64, device=self.device + ) + self.c_strides_13 = torch.full( + (self.e,), 2 * self.n, dtype=torch.int64, device=self.device + ) + self.c_strides_2 = torch.full( + (self.e,), self.k, dtype=torch.int64, device=self.device + ) + self.expert_offsets = torch.empty( + (self.e + 1,), dtype=torch.int32, device=self.device + ) + self.problem_sizes1 = torch.empty( + (self.e, 3), dtype=torch.int32, device=self.device + ) + self.problem_sizes2 = torch.empty( + (self.e, 3), dtype=torch.int32, device=self.device + ) + if self.cutlass_moe_type == CutlassMoEType.BlockscaledFP4: + self.blockscale_offsets = torch.empty( + (self.e + 1,), dtype=torch.int32, device=self.device + ) + else: + self.blockscale_offsets = None + self.a_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device) + self.b_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device) + self.out_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device) + self.a_scales_ptrs = torch.empty( + (self.e,), dtype=torch.int64, device=self.device + ) + self.b_scales_ptrs = torch.empty( + (self.e,), dtype=torch.int64, device=self.device + ) + + def to_gemm1_args(self) -> dict: + return { + "ab_strides": self.ab_strides_13, + "c_strides": self.c_strides_13, + "problem_sizes": self.problem_sizes1, + "expert_offsets": self.expert_offsets[:-1], + "blockscale_offsets": self.blockscale_offsets[:-1], + # 
"a_ptrs": self.a_ptrs, + # "b_ptrs": self.b_ptrs, + # "out_ptrs": self.out_ptrs, + # "a_scales_ptrs": self.a_scales_ptrs, + # "b_scales_ptrs": self.b_scales_ptrs, + } + + def to_gemm2_args(self) -> dict: + return { + "ab_strides": self.ab_strides_2, + "c_strides": self.c_strides_2, + "problem_sizes": self.problem_sizes2, + "expert_offsets": self.expert_offsets[:-1], + "blockscale_offsets": self.blockscale_offsets[:-1], + # "a_ptrs": self.a_ptrs, + # "b_ptrs": self.b_ptrs, + # "out_ptrs": self.out_ptrs, + # "a_scales_ptrs": self.a_scales_ptrs, + # "b_scales_ptrs": self.b_scales_ptrs, + } diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 0d2596ad9..3d8269a63 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -982,9 +982,9 @@ class Fp8MoEMethod: and self.block_quant and is_sm100_supported() ): - from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts + from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 - return cutlass_fused_experts( + return cutlass_fused_experts_fp8( x, layer.w13_weight.transpose(1, 2), layer.w2_weight.transpose(1, 2), diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py index 0a8f58d3f..496e6d487 100755 --- a/python/sglang/test/test_cutlass_moe.py +++ b/python/sglang/test/test_cutlass_moe.py @@ -6,7 +6,7 @@ import triton # Added import import triton.testing # Added import from transformers import AutoConfig -from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts +from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False): problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda") # --- Lambdas for Benchmarking --- - cutlass_lambda = lambda: cutlass_fused_experts( + 
cutlass_lambda = lambda: cutlass_fused_experts_fp8( x, w1.transpose(1, 2), # Transposed w2.transpose(1, 2), # Transposed @@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False): print("Running correctness check...") with torch.no_grad(): # Run CUTLASS version (requires transposed weights) - y_cutlass = cutlass_fused_experts( + y_cutlass = cutlass_fused_experts_fp8( x, w1.transpose(1, 2), # Transposed w2.transpose(1, 2), # Transposed diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py index df3d1f7c1..7e3de278c 100644 --- a/python/sglang/test/test_fp4_moe.py +++ b/python/sglang/test/test_fp4_moe.py @@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 +from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.topk import select_experts if torch.cuda.get_device_capability() < (10, 0): @@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph( (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device ) c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device) + params = CutlassMoEParams( + CutlassMoEType.BlockscaledFP4, + device=a.device, + num_experts=e, + intermediate_size_per_partition=n, # n + hidden_size=k, + ) # k cutlass_output = cutlass_moe_fp4( a=a, a1_gscale=a1_gs, @@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph( w2_fp4=w2_q, w2_blockscale=w2_blockscale, w2_alphas=(1 / w2_gs), - ab_strides_13=ab_strides_13, - ab_strides_2=ab_strides_2, - c_strides_13=c_strides_13, - c_strides_2=c_strides_2, topk_weights=topk_weights, topk_ids=topk_ids, - m=m, - n=n, - k=k, - e=e, - device=a.device, + params=params, + apply_router_weight_on_input=False, ) # Reference check: diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 0b60f2e18..b3497f517 100755 --- 
a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional import torch @@ -184,13 +184,9 @@ def cutlass_fp4_group_mm( a_blockscale, b_blockscale, alphas, - ab_strides, - c_strides, - problem_sizes, - expert_offsets, - blockscale_offsets, out_dtype, device, + params: Dict[str, Any], ): """ An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs @@ -220,10 +216,10 @@ def cutlass_fp4_group_mm( a_blockscale, b_blockscale, alphas, - ab_strides, - c_strides, - problem_sizes, - expert_offsets, - blockscale_offsets, + params["ab_strides"], + params["c_strides"], + params["problem_sizes"], + params["expert_offsets"], + params["blockscale_offsets"], ) return c.to(dtype=out_dtype)