diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py new file mode 100755 index 000000000..9046fc676 --- /dev/null +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -0,0 +1,207 @@ +"""Cutlass MoE kernel.""" + +import functools +import json +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch + +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() +if _is_cuda: + import sgl_kernel + from sgl_kernel import ( + fp8_blockwise_scaled_grouped_mm, + prepare_moe_input, + silu_and_mul, + ) + + +def cutlass_fused_experts( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a1_strides: torch.Tensor, + c1_strides: torch.Tensor, + a2_strides: torch.Tensor, + c2_strides: torch.Tensor, + workspace: torch.Tensor, + a_ptrs: torch.Tensor, + b_ptrs: torch.Tensor, + out_ptrs: torch.Tensor, + a_scales_ptrs: torch.Tensor, + b_scales_ptrs: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + use_fp8_blockscale: bool = True, +) -> torch.Tensor: + """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations. + + This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU + activation, leveraging custom kernels likely derived from CUTLASS principles + for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and + data preparation (`prepare_moe_input`, `silu_and_mul`). + + It handles per-token routing, quantizes input activations to FP8 with + per-token scales, performs the expert computations using FP8 GEMMs with + pre-quantized FP8 weights (per-block scales), applies the SiLU activation, + and combines the results weighted by the router scores. + + Args: + a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total + number of tokens and `k` is the hidden size. Expected dtype: `torch.half` + or `torch.bfloat16`. + w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM + (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where + `E` is the number of experts, `k` is the hidden size, and `n*2` is the + intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`. + Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size). + w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM + (down-projection). Expected shape: `(E, n, k)`, where `n` is half the + intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`. + Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size). + w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales). + Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`. + w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales). + Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`. + topk_weights (torch.Tensor): Router weights for the selected top-k experts + for each token. Shape: `(m, topk)`. Dtype should ideally match `a`. + topk_ids (torch.Tensor): Indices of the selected top-k experts for each token. + Shape: `(m, topk)`. Dtype: `torch.int32`. + a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input. + Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`. 
+ Note: the same tensor is passed as both the A and B stride in the first grouped GEMM call. + c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output. + Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`. + a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input. + Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`. + Note: the same tensor is passed as both the A and B stride in the second grouped GEMM call. + c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output. + Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`. + workspace (torch.Tensor): Reusable byte workspace for the underlying grouped GEMM kernel. + a_ptrs (torch.Tensor): Per-expert pointer buffer for the input activations; filled by the grouped GEMM kernel. + b_ptrs (torch.Tensor): Per-expert pointer buffer for the expert weights; filled by the grouped GEMM kernel. + out_ptrs (torch.Tensor): Per-expert pointer buffer for the output activations; filled by the grouped GEMM kernel. + a_scales_ptrs (torch.Tensor): Per-expert pointer buffer for the activation scales; filled by the grouped GEMM kernel. + b_scales_ptrs (torch.Tensor): Per-expert pointer buffer for the weight scales; filled by the grouped GEMM kernel. + expert_offsets (torch.Tensor): Buffer for per-expert token offsets; filled by `prepare_moe_input`. + Shape: `(E + 1,)`. Dtype: `torch.int32`. + problem_sizes1 (torch.Tensor): Buffer for per-expert problem shapes of the first grouped GEMM; + filled by `prepare_moe_input`. Shape: `(E, 3)`. Dtype: `torch.int32`. + problem_sizes2 (torch.Tensor): Buffer for per-expert problem shapes of the second grouped GEMM; + filled by `prepare_moe_input`. Shape: `(E, 3)`. Dtype: `torch.int32`. + use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with + block scaling. Currently, only `True` is supported. Defaults to `True`. + + Returns: + torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`. + + Raises: + AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported. + NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
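+
+ Example:
+ A minimal sketch of a call site (illustrative only): it assumes `a`, `w1_q`, `w2_q`,
+ `w1_scale`, `w2_scale`, `topk_weights`, and `topk_ids` have been prepared as described
+ above, and mirrors the scratch-buffer allocation used by `Fp8MoEMethod` in `fp8.py`:
+
+ >>> E = w1_q.size(0)
+ >>> hidden, half_inter = w1_q.size(1), w2_q.size(1)
+ >>> ab_strides1 = torch.full((E,), hidden, dtype=torch.int64, device="cuda")
+ >>> c_strides1 = torch.full((E,), 2 * half_inter, dtype=torch.int64, device="cuda")
+ >>> ab_strides2 = torch.full((E,), half_inter, dtype=torch.int64, device="cuda")
+ >>> c_strides2 = torch.full((E,), hidden, dtype=torch.int64, device="cuda")
+ >>> workspace = torch.empty(90000, dtype=torch.uint8, device="cuda")
+ >>> a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs = (
+ ...     torch.empty(E, dtype=torch.int64, device="cuda") for _ in range(5)
+ ... )
+ >>> expert_offsets = torch.empty(E + 1, dtype=torch.int32, device="cuda")
+ >>> problem_sizes1 = torch.empty(E, 3, dtype=torch.int32, device="cuda")
+ >>> problem_sizes2 = torch.empty(E, 3, dtype=torch.int32, device="cuda")
+ >>> out = cutlass_fused_experts(
+ ...     a, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids,
+ ...     ab_strides1, c_strides1, ab_strides2, c_strides2, workspace,
+ ...     a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs,
+ ...     expert_offsets, problem_sizes1, problem_sizes2,
+ ... )
+ >>> out.shape  # (m, hidden)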
+ """ + assert use_fp8_blockscale, "Only support fp8 blockscale for now" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_q.dtype == torch.float8_e4m3fn + assert w2_q.dtype == torch.float8_e4m3fn + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + if is_cuda: + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) + + out_dtype = a.dtype + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + topk = topk_ids.size(1) + + a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128) + device = a_q.device + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + prepare_moe_input( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] + + c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) + c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + + a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int) + w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int) + + fp8_blockwise_scaled_grouped_mm( + c1, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + rep_a_q, + w1_q, + rep_a1_scales, + w1_scale, + a1_strides, + a1_strides, + c1_strides, + a_sf_layout, + w_sf_layout, + problem_sizes1, + expert_offsets[:-1], + workspace, + ) + + intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) + silu_and_mul(c1, intermediate) + + intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128) + + fp8_blockwise_scaled_grouped_mm( + c2, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + intemediate_q, + w2_q, + a2_scale, + w2_scale, + a2_strides, + a2_strides, + c2_strides, + a_sf_layout, + w_sf_layout, + problem_sizes2, + expert_offsets[:-1], + workspace, + ) + return ( + c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) + ).sum(dim=1) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 521ba7deb..e43d1f0bb 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( apply_w8a8_block_fp8_linear, cutlass_fp8_supported, input_to_float8, + is_sm100_supported, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod @@ -470,6 +471,7 @@ class Fp8MoEMethod: def __init__(self, quant_config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + self.cutlass_fp8_supported = cutlass_fp8_supported() def create_weights( self, @@ -568,6 +570,63 @@ class Fp8MoEMethod: layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) layer.register_parameter("w2_weight_scale_inv", 
w2_weight_scale) assert self.quant_config.activation_scheme == "dynamic" + if ( + get_bool_env_var("CUTLASS_MOE") + and self.cutlass_fp8_supported + and is_sm100_supported() + ): + self.ab_strides1 = torch.full( + (num_experts,), + hidden_size, + device=w13_weight.device, + dtype=torch.int64, + ) + self.c_strides1 = torch.full( + (num_experts,), + 2 * intermediate_size, + device=w13_weight.device, + dtype=torch.int64, + ) + self.ab_strides2 = torch.full( + (num_experts,), + intermediate_size, + device=w2_weight.device, + dtype=torch.int64, + ) + self.c_strides2 = torch.full( + (num_experts,), + hidden_size, + device=w2_weight.device, + dtype=torch.int64, + ) + self.workspace = torch.empty( + 90000, device=w13_weight.device, dtype=torch.uint8 + ) + self.a_ptr = torch.empty( + num_experts, device=w13_weight.device, dtype=torch.int64 + ) + self.b_ptr = torch.empty( + num_experts, device=w13_weight.device, dtype=torch.int64 + ) + self.out_ptr = torch.empty( + num_experts, device=w13_weight.device, dtype=torch.int64 + ) + self.a_scales_ptr = torch.empty( + num_experts, device=w13_weight.device, dtype=torch.int64 + ) + self.b_scales_ptr = torch.empty( + num_experts, device=w13_weight.device, dtype=torch.int64 + ) + self.expert_offsets = torch.empty( + num_experts + 1, device=w13_weight.device, dtype=torch.int32 + ) + self.problem_sizes1 = torch.empty( + num_experts, 3, device=w13_weight.device, dtype=torch.int32 + ) + self.problem_sizes2 = torch.empty( + num_experts, 3, device=w13_weight.device, dtype=torch.int32 + ) + else: # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. @@ -913,6 +972,37 @@ class Fp8MoEMethod: if ret is not None: return ret + if ( + get_bool_env_var("CUTLASS_MOE") + and self.cutlass_fp8_supported + and self.block_quant + and is_sm100_supported() + ): + from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts + + return cutlass_fused_experts( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale_inv.transpose(1, 2), + layer.w2_weight_scale_inv.transpose(1, 2), + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + self.workspace, + self.a_ptr, + self.b_ptr, + self.out_ptr, + self.a_scales_ptr, + self.b_scales_ptr, + self.expert_offsets, + self.problem_sizes1, + self.problem_sizes2, + use_fp8_blockscale=True, + ) # Expert fusion with FP8 quantization return fused_experts( x, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 0602144e7..05e43fe3f 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -80,6 +80,12 @@ def cutlass_fp8_supported(): return False +def is_sm100_supported(device=None) -> bool: + return (torch.cuda.get_device_capability(device)[0] == 10) and ( + torch.version.cuda >= "12.8" + ) + + def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py new file mode 100755 index 000000000..0a8f58d3f --- /dev/null +++ b/python/sglang/test/test_cutlass_moe.py @@ -0,0 +1,278 @@ +import argparse +import time + +import torch +import triton # Added import +import triton.testing # Added import +from transformers import AutoConfig + +from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts +from 
sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + +def get_model_config(tp_size: int): + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-R1", trust_remote_code=True + ) + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + + return { + "num_experts": E, + "topk": topk, + "hidden_size": config.hidden_size, + "shard_intermediate_size": shard_intermediate_size, + "dtype": config.torch_dtype, + "block_shape": config.quantization_config["weight_block_size"], + } + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + """Converts tensor to FP8 E4M3, scaling values to fit the range.""" + finfo = torch.finfo(torch.float8_e4m3fn) + # Calculate max absolute value safely + max_val = torch.max(torch.abs(tensor)) + # Avoid division by zero if tensor is all zeros + if max_val == 0: + scale_factor = 1.0 + else: + # Scale factor to bring the max value to finfo.max + scale_factor = finfo.max / max_val + + # Apply scaling + scaled_tensor = tensor * scale_factor + + # Clamp and convert + fp8_tensor = scaled_tensor.clamp(min=finfo.min, max=finfo.max).to( + dtype=torch.float8_e4m3fn + ) + return fp8_tensor + + +def run_test(tp_size, batch_size, model_config, check=False): + print(f"\n--- Batch Size: {batch_size} ---") + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(42) # For reproducible random numbers + + E = model_config["num_experts"] + topk = model_config["topk"] + H = model_config["hidden_size"] + I = model_config["shard_intermediate_size"] + block_shape = model_config["block_shape"] # Tuple (BLOCK_N, BLOCK_K) + dtype = model_config["dtype"] # e.g., torch.bfloat16 + + print( + f"Config: E={E}, topk={topk}, H={H}, I_shard={I}, dtype={dtype}, block_shape={block_shape}" + ) + + # --- Input Data --- + # Use bf16/fp16 for input activation based on model config + x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001 + # --- Weights (Generate in higher precision, then convert to FP8) --- + # Generate weights suitable for FP8 conversion (e.g., scaled appropriately) + w1_hp = ( + torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001 + ) + w2_hp = ( + torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001 + + 0.00001 + ) + + w1 = to_fp8(w1_hp) + w2 = to_fp8(w2_hp) + + # --- Scales for FP8 Weights --- + block_n, block_k = block_shape + # Calculate number of blocks needed + w1_blocks_dim1 = (I + block_n - 1) // block_n + w1_blocks_dim2 = (H + block_k - 1) // block_k + w2_blocks_dim1 = (H + block_n - 1) // block_n + w2_blocks_dim2 = (I // 2 + block_k - 1) // block_k + + # Scales are typically float32 or float16/bfloat16 + scale_dtype = torch.float32 # Or dtype if scales match model dtype + w1_scale = torch.full( + (E, w1_blocks_dim1, w1_blocks_dim2), 1, device="cuda", dtype=scale_dtype + ) # Avoid zero scales + w2_scale = torch.full( + (E, w2_blocks_dim1, w2_blocks_dim2), 1, device="cuda", dtype=scale_dtype + ) # Avoid zero scales + + # --- Routing Information --- + topk_weights = torch.softmax( + torch.rand(batch_size, topk, device="cuda", dtype=dtype), dim=-1 + ) + topk_ids = torch.randint(0, E, (batch_size, topk), dtype=torch.int32, device="cuda") + + a1_strides = torch.full((E,), H, dtype=torch.int64, device="cuda") + c1_strides = torch.full((E,), I, dtype=torch.int64, device="cuda") + a2_strides = torch.full((E,), I // 2, dtype=torch.int64, device="cuda") + c2_strides = 
torch.full((E,), H, dtype=torch.int64, device="cuda") + + workspace = torch.empty( + (7182 * 1024), device="cuda", dtype=torch.uint8 + ) # Allocate sufficient workspace + # Pointer arrays (often filled by the kernel or a prep step, but needed as args) + a_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda") + b_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda") + out_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda") + a_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda") + b_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda") + expert_offsets = torch.empty((E + 1,), dtype=torch.int32, device="cuda") + problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda") + problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda") + + # --- Lambdas for Benchmarking --- + cutlass_lambda = lambda: cutlass_fused_experts( + x, + w1.transpose(1, 2), # Transposed + w2.transpose(1, 2), # Transposed + w1_scale.transpose(1, 2), + w2_scale.transpose(1, 2), + topk_weights, + topk_ids, + a1_strides, + c1_strides, + a2_strides, + c2_strides, + workspace, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + expert_offsets, + problem_sizes1, + problem_sizes2, + ) + + # Note: Triton expects non-transposed weights + triton_lambda = lambda: fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=False, # Use False for benchmarking to avoid side effects if run multiple times + activation="silu", # Assuming SiLU activation common in MoEs + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + block_shape=block_shape, + ) + + # --- Warmup --- + print("Warming up...") + for _ in range(10): + _ = cutlass_lambda() + _ = triton_lambda() + torch.cuda.synchronize() + + # --- Benchmarking --- + quantiles = [0.5, 0.2, 0.8] + print(f"Benchmarking Cutlass fused_experts...") + cutlass_ms, cutlass_min, cutlass_max = triton.testing.do_bench_cudagraph( + cutlass_lambda, rep=1000, quantiles=quantiles + ) + + print(f"Benchmarking Triton fused_experts...") + triton_ms, triton_min, triton_max = triton.testing.do_bench_cudagraph( + triton_lambda, rep=1000, quantiles=quantiles + ) + print( + f"Cutlass fused_experts time: {cutlass_ms:.3f} ms (median) [{cutlass_min:.3f} - {cutlass_max:.3f}]" + ) + print( + f"Triton fused_experts time: {triton_ms:.3f} ms (median) [{triton_min:.3f} - {triton_max:.3f}]" + ) + + # --- Correctness Check --- + if check: + print("Running correctness check...") + with torch.no_grad(): + # Run CUTLASS version (requires transposed weights) + y_cutlass = cutlass_fused_experts( + x, + w1.transpose(1, 2), # Transposed + w2.transpose(1, 2), # Transposed + w1_scale.transpose(1, 2), + w2_scale.transpose(1, 2), + topk_weights, + topk_ids, + a1_strides, + c1_strides, + a2_strides, + c2_strides, + workspace, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + expert_offsets, + problem_sizes1, + problem_sizes2, + ) + + # Run Triton version (requires original shape weights, use inplace=False) + y_triton = fused_experts( + x, + w1, # Original shape + w2, # Original shape + topk_weights, + topk_ids, + inplace=False, # Important: Use False to get output tensor + activation="silu", + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + block_shape=block_shape, + ) + + # Ensure outputs are same dtype for comparison + y_cutlass = y_cutlass.to(dtype) + y_triton = y_triton.to(dtype) + + abs_error = torch.abs(y_cutlass - y_triton) + rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2) + + 
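+ # Note: the relative error above uses a clamped denominator (min=1e-2) so near-zero
+ # reference outputs do not dominate; exact matches are not expected for FP8
+ # block-quantized kernels, hence the loose tolerance asserted below.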
max_abs_err = abs_error.max().item() + max_rel_err = rel_error.max().item() + + print("y_cutlass:", y_cutlass[:, :10]) + print("y_triton:", y_triton[:, :10]) + print(f"Max absolute error: {max_abs_err:.6f}") + print(f"Max relative error: {max_rel_err:.6f}") + + # Tolerance might need adjustment based on FP8 specifics and kernel differences + # FP8 comparisons often require higher tolerance than FP16/BF16 + assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}" + print("Correctness check passed.") + + +def main(tp_size=8, batch_sizes=[1, 4, 8, 16, 32, 64, 128, 256, 512], check=False): + model_config = get_model_config(tp_size) + print("Model Config:", model_config) + for batch_size in batch_sizes: + run_test(tp_size, batch_size, model_config, check) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tp-size", type=int, default=8, help="Tensor Parallel size") + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=[1, 4, 8, 16, 32, 64, 128, 256, 512], # Adjusted default + help="List of batch sizes to test", + ) + parser.add_argument("--check", action="store_true", help="Enable check mode") + args = parser.parse_args() + + print(f"Running benchmarks with TP size: {args.tp_size}") + print(f"Testing batch sizes: {args.batch_sizes}") + + main(tp_size=args.tp_size, batch_sizes=args.batch_sizes, check=args.check) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt old mode 100755 new mode 100644 index 5d52a1345..a2858d3ec --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -207,6 +207,7 @@ set(SOURCES "csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_topk_softmax_kernels.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu" + "csrc/moe/prepare_moe_input.cu" "csrc/speculative/eagle_utils.cu" "csrc/speculative/speculative_sampling.cu" "csrc/speculative/packbit.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 35bd7a1cb..649bf4297 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -151,11 +151,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); m.def( - "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " + "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor " + "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " "stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor " - "expert_offsets) -> ()"); + "expert_offsets, Tensor workspace) -> ()"); m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm); - + m.def( + "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor " + "input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()"); + m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input); /* * From csrc/speculative */ diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu index d85293dce..b51849234 100644 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -1,3 +1,5 @@ +#include +#include #include #include @@ -49,23 +51,16 @@ void launch_sm100_fp8_blockwise_scaled_group_mm( using ElementC = OutType; using ElementD = 
ElementC; using ElementAccumulator = float; - // Layout definitions using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = LayoutD; - // Alignment constraints static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - // Architecture definitions using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; - // For fp8 block scale. - // using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig; using LayoutSFA = - // decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -124,9 +119,8 @@ void launch_sm100_fp8_blockwise_scaled_group_mm( cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; - hw_info.sm_count = 1; - // Currently, we are only able to do broadcast on either all or none a_scales - // and on either all or none b_scales + // sm_count is the number of SMs on the current device, since we only support SM100 blackwell, so we set it to 148 + hw_info.sm_count = 148; typename GemmKernel::EpilogueArguments epilogue_args{ {}, nullptr, @@ -134,9 +128,7 @@ void launch_sm100_fp8_blockwise_scaled_group_mm( static_cast(out_ptrs.data_ptr()), static_cast(stride_c.data_ptr())}; - // Initialize problem_sizes_as_shapes correctly UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); - // Use prob_shape in the GEMM arguments typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, {num_experts, problem_sizes_as_shapes, nullptr}, @@ -144,21 +136,27 @@ void launch_sm100_fp8_blockwise_scaled_group_mm( epilogue_args, hw_info}; + at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); + auto can_implement_status = gemm_op.can_implement(args); TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); - // Run the GEMM - auto status = gemm_op.initialize(args, workspace.data_ptr()); - + auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); - status = gemm_op.run(); + status = gemm_op.run(stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } template void sm100_fp8_blockwise_group_mm_dispatch_shape( torch::Tensor& output, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, @@ -169,11 +167,23 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets) { + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace) { // Check the first matrix size to decide on the configuration // Assuming all matrices in the group have similar size characteristics // bool use_small_config = a[0].size(0) <= 128; - struct MMALargeConfig { + struct MmaConfig1 { + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_128, _32, _128>; + using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand + 
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using ScaleConfig = + cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + }; + struct MmaConfig2 { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand @@ -184,35 +194,28 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; - - struct MMASmallConfig { + struct MmaConfig3 { using ElementA = cutlass::float_e4m3_t; - using MmaTileShape = Shape<_128, _16, _128>; + using MmaTileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; using ScaleConfig = - cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; + cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; int num_experts = (int)expert_offsets.size(0); torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - torch::Tensor a_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_ptrs = torch::empty(num_experts, options_int); - torch::Tensor out_ptrs = torch::empty(num_experts, options_int); - torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int); - torch::Tensor workspace = torch::empty(100, options_int); torch::Tensor output_t = output.t(); torch::Tensor a_t = a.t(); torch::Tensor b_t = b.transpose(1, 2); torch::Tensor scales_a_t = scales_a.t(); torch::Tensor scales_b_t = scales_b.transpose(1, 2); - if (a.size(0) <= 512) { - run_get_group_gemm_starts( + if (a.size(0) <= 512 && a.size(1) >= 2048) { + run_get_group_gemm_starts( expert_offsets, a_ptrs, b_ptrs, @@ -229,7 +232,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( problem_sizes, problem_sizes_transpose, true); - launch_sm100_fp8_blockwise_scaled_group_mm( + launch_sm100_fp8_blockwise_scaled_group_mm( out_ptrs, a_ptrs, b_ptrs, @@ -244,8 +247,8 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( expert_offsets, workspace); output = output_t.t(); - } else { - run_get_group_gemm_starts( + } else if (a.size(0) > 512 && a.size(1) >= 2048) { + run_get_group_gemm_starts( expert_offsets, a_ptrs, b_ptrs, @@ -261,7 +264,38 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( layout_sfb, problem_sizes, problem_sizes_transpose); - launch_sm100_fp8_blockwise_scaled_group_mm( + launch_sm100_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets, + workspace); + } else { + run_get_group_gemm_starts( + expert_offsets, + a_ptrs, + 
b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + output, + scales_a, + scales_b, + layout_sfa, + layout_sfb, + problem_sizes, + problem_sizes_transpose); + launch_sm100_fp8_blockwise_scaled_group_mm( out_ptrs, a_ptrs, b_ptrs, @@ -312,6 +346,11 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( */ void fp8_blockwise_scaled_grouped_mm( torch::Tensor& output, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, @@ -322,7 +361,8 @@ void fp8_blockwise_scaled_grouped_mm( const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets) { + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace) { TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); TORCH_CHECK( @@ -342,6 +382,29 @@ void fp8_blockwise_scaled_grouped_mm( TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32"); TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32"); + TORCH_CHECK(output.dim() == 2, "output must be 2D tensor"); + TORCH_CHECK(a.dim() == 2, "a must be 2D tensor"); + TORCH_CHECK(b.dim() == 3, "b must be 3D tensor"); + TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor"); + TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor"); + TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor"); + TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor"); + TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor"); + TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 1D tensor"); + TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 1D tensor"); + TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor"); + TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor"); + TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor"); + TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor"); + TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor"); + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); + TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); + TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor"); + bool can_implement = false; auto sm_version = getSMVersion(); @@ -351,6 +414,11 @@ void fp8_blockwise_scaled_grouped_mm( if (output.scalar_type() == torch::kBFloat16) { sm100_fp8_blockwise_group_mm_dispatch_shape( output, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, a, b, scales_a, @@ -361,10 +429,16 @@ void fp8_blockwise_scaled_grouped_mm( layout_sfa, layout_sfb, problem_sizes, - expert_offsets); + expert_offsets, + workspace); } else { sm100_fp8_blockwise_group_mm_dispatch_shape( output, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, a, b, scales_a, @@ -375,7 +449,8 @@ void fp8_blockwise_scaled_grouped_mm( layout_sfa, layout_sfb, problem_sizes, - expert_offsets); + expert_offsets, + workspace); } 
can_implement = true; } diff --git a/sgl-kernel/csrc/moe/prepare_moe_input.cu b/sgl-kernel/csrc/moe/prepare_moe_input.cu new file mode 100644 index 000000000..5f3010301 --- /dev/null +++ b/sgl-kernel/csrc/moe/prepare_moe_input.cu @@ -0,0 +1,128 @@ +#include +#include +#include + +#include + +constexpr uint64_t THREADS_PER_EXPERT = 512; + +__global__ void compute_problem_sizes( + const int* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, + const int n, + const int k) { + int expert_id = blockIdx.x; + + int occurrences = 0; + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + __syncthreads(); + + if (threadIdx.x == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_expert_offsets( + const int32_t* __restrict__ problem_sizes1, + int32_t* expert_offsets, + int32_t* atomic_buffer, + const int num_experts) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + } +} + +__global__ void compute_arg_sorts( + const int* __restrict__ topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, + const int topk_length, + const int topk) { + int expert_id = blockIdx.x; + + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + input_permutation[start] = i / topk; + output_permutation[i] = start; + } + } +} + +void get_moe_prepare_input_caller( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>( + static_cast<const int*>(topk_ids.data_ptr()), + static_cast<int32_t*>(problem_sizes1.data_ptr()), + static_cast<int32_t*>(problem_sizes2.data_ptr()), + static_cast<int32_t*>(atomic_buffer.data_ptr()), + topk_ids.numel(), + n, + k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast<const int32_t*>(problem_sizes1.data_ptr()), + static_cast<int32_t*>(expert_offsets.data_ptr()), + static_cast<int32_t*>(atomic_buffer.data_ptr()), + num_experts); + compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( + static_cast<const int*>(topk_ids.data_ptr()), + static_cast<int32_t*>(input_permutation.data_ptr()), + static_cast<int32_t*>(output_permutation.data_ptr()), + static_cast<int32_t*>(atomic_buffer.data_ptr()), + topk_ids.numel(), + topk_ids.size(1)); +} + +void prepare_moe_input( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts,
const int64_t n, + const int64_t k) { + TORCH_CHECK(topk_ids.dtype() == torch::kInt32); + get_moe_prepare_input_caller( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k); + return; +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index bf608456d..658f6950e 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -211,6 +211,11 @@ std::vector moe_fused_gate( void fp8_blockwise_scaled_grouped_mm( torch::Tensor& output, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, @@ -221,7 +226,19 @@ void fp8_blockwise_scaled_grouped_mm( const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets); + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace); + +void prepare_moe_input( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k); /* * From csrc/speculative diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 0aaf09042..70b5cdc77 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -47,6 +47,7 @@ from sgl_kernel.moe import ( fp8_blockwise_scaled_grouped_mm, moe_align_block_size, moe_fused_gate, + prepare_moe_input, topk_softmax, ) from sgl_kernel.sampling import ( diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index f989fb8f7..e7b5eede0 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -64,6 +64,11 @@ def moe_fused_gate( def fp8_blockwise_scaled_grouped_mm( output, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, a, b, scales_a, @@ -75,9 +80,15 @@ def fp8_blockwise_scaled_grouped_mm( layout_sfb, problem_sizes, expert_offsets, + workspace, ): torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default( output, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, a, b, scales_a, @@ -89,4 +100,29 @@ def fp8_blockwise_scaled_grouped_mm( layout_sfb, problem_sizes, expert_offsets, + workspace, + ) + + +def prepare_moe_input( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k, +): + torch.ops.sgl_kernel.prepare_moe_input.default( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k, ) diff --git a/sgl-kernel/tests/test_fp8_blockwise_moe.py b/sgl-kernel/tests/test_fp8_blockwise_moe.py index 18493f007..40415a582 100755 --- a/sgl-kernel/tests/test_fp8_blockwise_moe.py +++ b/sgl-kernel/tests/test_fp8_blockwise_moe.py @@ -131,9 +131,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype): c_strides = torch.full( (num_experts,), c_out.stride(0), device=device, dtype=torch.int64 ) + workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8) + a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + b_ptrs = torch.empty((num_experts,), device=device, 
dtype=torch.int64) + out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) fp8_blockwise_scaled_grouped_mm( c_out, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, a_stack, b_stack, a_scale_stack, @@ -145,6 +156,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype): layout_sfb, problem_sizes, expert_offsets[:-1], + workspace, ) for g in range(num_experts):