From 9a30914e94fa80ab2647e9f0fcdcbf4d9c0aec89 Mon Sep 17 00:00:00 2001 From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com> Date: Mon, 13 Oct 2025 11:19:21 +0800 Subject: [PATCH] [sgl-kernel][1/N]Support Expert Specialization Grouped GEMM (#11432) Co-authored-by: luoyuan.luo Co-authored-by: PGFLMG <1106310035@qq.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> --- sgl-kernel/CMakeLists.txt | 2 + .../bench_es_fp8_blockwise_grouped_gemm.py | 336 ++++++++++++++++++ sgl-kernel/csrc/common_extension.cc | 8 + .../expert_specialization/es_fp8_blockwise.cu | 126 +++++++ .../es_fp8_blockwise_functor.cuh | 268 ++++++++++++++ .../es_fp8_blockwise_launcher.cuh | 312 ++++++++++++++++ .../es_fp8_blockwise_traits.cuh | 174 +++++++++ sgl-kernel/include/sgl_kernel_ops.h | 15 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + .../python/sgl_kernel/expert_specilization.py | 27 ++ sgl-kernel/tests/test_es_fp8_blockwise_moe.py | 204 +++++++++++ 11 files changed, 1473 insertions(+) create mode 100644 sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py create mode 100644 sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu create mode 100644 sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh create mode 100644 sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh create mode 100644 sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh create mode 100644 sgl-kernel/python/sgl_kernel/expert_specilization.py create mode 100644 sgl-kernel/tests/test_es_fp8_blockwise_moe.py diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index b7ba690d5..7133ad652 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -323,6 +323,8 @@ set(SOURCES "csrc/speculative/packbit.cu" "csrc/speculative/speculative_sampling.cu" + "csrc/expert_specialization/es_fp8_blockwise.cu" + "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" diff --git a/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py b/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py new file mode 100644 index 000000000..1cc06d4b5 --- /dev/null +++ b/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py @@ -0,0 +1,336 @@ +import argparse +import random +from dataclasses import dataclass +from typing import List, Tuple + +import numpy as np +import torch +from sgl_kernel import ( + es_fp8_blockwise_scaled_grouped_mm, + fp8_blockwise_scaled_grouped_mm, +) + +random.seed(28) + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (128 - (n % 128)) % 128 + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +def create_unbalanced_expert_token_distribution(max_num_experts): + ratios = [random.random() for _ in range(max_num_experts)] + + def convert_to_tokens(ratio: float): + if ratio <= 0.7: + return random.randint(1, 32) + elif ratio > 0.7 and ratio <= 0.85: + return random.randint(32, 64) + elif ratio > 0.85 and ratio <= 0.95: + return random.randint(64, 128) + elif ratio > 0.95: + return random.randint(128, 1024) + else: + return 128 + + group_ms = [convert_to_tokens(ratio) for ratio in ratios] + return group_ms + + +group_ms = create_unbalanced_expert_token_distribution(8192) +# group_ms = [128 for _ in range(8192)] +# group_ms = [128 if i % 2 == 0 else 64 for i in range(8192)] + + +def bench_es( + n: int, + k: int, + num_groups: int, + num_warmup: int, + num_run: int, +) -> Tuple[float, int]: + device = "cuda" + alignment = 128 + n_g = ceil_div(n, alignment) * alignment + k_g = ceil_div(k, alignment) * alignment + out_dtype = torch.bfloat16 + + expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32) + problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32) + + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + if False: + print("Token Distributtion: ", group_ms[0:num_groups]) + print("Token Count: ", sum(group_ms[0:num_groups])) + for g in range(num_groups): + m_g = group_ms[g] + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) + + a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device)) + b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t()) + a_tensors.append(a_g) + b_tensors.append(b_g) + a_scales_tensors.append(a_scale) + b_scales_tensors.append(b_scale) + + a_stack = torch.empty( + (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn + ) + b_stack = torch.empty( + (num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) + + for g in range(num_groups): + a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g] + b_stack[g] = b_tensors[g].t() + b_stack = b_stack.transpose(1, 2) + + a_scale_stack = torch.empty( + (expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32 + ) + b_scale_stack = torch.empty( + (num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32 + ) + + for g in range(num_groups): + a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g] + b_scale_stack[g] = b_scales_tensors[g].t() + b_scale_stack = b_scale_stack.transpose(1, 2) + + c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) + a_strides = torch.full( + (num_groups,), a_stack.stride(0), device=device, dtype=torch.int64 + ) + d_strides = torch.full( + (num_groups,), c_out.stride(0), device=device, dtype=torch.int64 + ) + + def run_cutlass(): + es_fp8_blockwise_scaled_grouped_mm( + c_out, + a_stack, + b_stack, + a_scale_stack, + b_scale_stack, + a_strides, + a_strides, + d_strides, + problem_sizes, + expert_offsets[:-1], + ) + + run_cutlass() + # warmup + for _ in range(num_warmup): + run_cutlass() + torch.cuda.synchronize() + + # run + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_run): + run_cutlass() + end_event.record() + end_event.synchronize() + torch.cuda.synchronize() + avg = start_event.elapsed_time(end_event) / num_run * 1000 # us + + return avg, expert_offsets[-1] + + +def bench_sgl( + n: int, + k: int, + num_groups: int, + num_warmup: int, + num_run: int, +) -> Tuple[float, int]: + device = "cuda" + alignment = 128 + n_g = ceil_div(n, alignment) * alignment + k_g = ceil_div(k, alignment) * alignment + out_dtype = torch.bfloat16 + + expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32) + problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32) + layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32) + layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32) + + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + + for g in range(num_groups): + m_g = group_ms[g] + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) + + a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device)) + b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t()) + a_tensors.append(a_g) + b_tensors.append(b_g) + a_scales_tensors.append(a_scale) + b_scales_tensors.append(b_scale) + + a_stack = torch.empty( + (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn + ) + b_stack = torch.empty( + (num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) + + for g in range(num_groups): + a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g] + b_stack[g] = b_tensors[g].t() + b_stack = b_stack.transpose(1, 2) + + a_scale_stack = torch.empty( + (expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32 + ) + b_scale_stack = torch.empty( + (num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32 + ) + + for g in range(num_groups): + a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g] + b_scale_stack[g] = b_scales_tensors[g].t() + b_scale_stack = b_scale_stack.transpose(1, 2) + + c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) + a_strides = torch.full( + (num_groups,), a_stack.stride(0), device=device, dtype=torch.int64 + ) + c_strides = torch.full( + (num_groups,), c_out.stride(0), device=device, dtype=torch.int64 + ) + workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8) + a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + + def run_cutlass(): + fp8_blockwise_scaled_grouped_mm( + c_out, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a_stack, + b_stack, + a_scale_stack, + b_scale_stack, + a_strides, + a_strides, + c_strides, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets[:-1], + workspace, + ) + + # warmup + for _ in range(num_warmup): + run_cutlass() + torch.cuda.synchronize() + + # run + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_run): + run_cutlass() + end_event.record() + end_event.synchronize() + torch.cuda.synchronize() + avg = start_event.elapsed_time(end_event) / num_run * 1000 # us + + return avg, expert_offsets[-1] + + +benchmark_kernels = {"es": bench_es, "sgl-kernel": bench_sgl} + + +@dataclass +class ShapeArg: + n: int + k: int + num_groups: int + + +def benchmark_one_shape( + shape_args: List[ShapeArg], + num_warmup: int, + num_run: int, +): + for shape in shape_args: + print(f"\nBenchmark: n={shape.n}, k={shape.k}, num_groups={shape.num_groups}") + for kernel_name, kernel_func in benchmark_kernels.items(): + average_time, m = kernel_func( + shape.n, + shape.k, + shape.num_groups, + num_warmup, + num_run, + ) + print(f"{kernel_name}: {average_time} us") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num-warmup", type=int, default=3) + parser.add_argument("--num-run", type=int, default=20) + shape_args = [ + # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8 + ShapeArg(n=512, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, down, chunk_size = 4096, TP = 8 + ShapeArg(n=7168, k=256, num_groups=256), + # Prefill, Qwen3-235B-A22B-FP8, gateup, TP = 4 + ShapeArg(n=768, k=4096, num_groups=128), + # Prefill, Qwen3-235B-A22B-FP8, down, TP = 4 + ShapeArg(n=4096, k=384, num_groups=128), + # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8 + ShapeArg(n=4096, k=7168, num_groups=32), + # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16 + ShapeArg(n=4096, k=7168, num_groups=16), + ] + args = parser.parse_args() + benchmark_one_shape(shape_args, args.num_warmup, args.num_run) + + +if __name__ == "__main__": + main() diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 44a0d1e57..8fce6a276 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -531,6 +531,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "bool silu_activation," "int pad_slot_id) -> ()"); m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); + + /* + * From csrc/expert_sepcialization + */ + m.def( + "es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " + "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets) -> ()"); + m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu new file mode 100644 index 000000000..f7e4b61ef --- /dev/null +++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu @@ -0,0 +1,126 @@ +#include + +#include + +#include "es_fp8_blockwise_launcher.cuh" + +/** + * @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs, + * with per-block scaling. + * + * This function dispatches to hardware-specific implementations (e.g., SM100 FP8) + * to compute: + * C_i = scale_a[i] * A_i * scale_b[i] * B_i + * for each expert group `i`, using input `problem_sizes` and `expert_offsets` + * to describe the individual matrix dimensions and their offsets. + * + * Input tensors A and B must be quantized to 8-bit formats and dequantized before multiplication. + * The output tensor is written with bfloat16 or half precision. + * + * @param output Output tensor (must be of type bfloat16 or half). + * @param a Input tensor A (must be kFloat8_e4m3fn). + * @param b Input tensor B (must be kFloat8_e4m3fn). + * @param scales_a Scaling factors for tensor A, float32 per expert group. + * @param scales_b Scaling factors for tensor B, float32 per expert group. + * @param stride_a Stride information for tensor A (int32). + * @param stride_b Stride information for tensor B (int32). + * @param stride_c Stride information for output tensor C (int32). + * @param problem_sizes 2D int32 tensor of shape (num_experts, 3), specifying (M, N, K) + * for each grouped matrix multiplication problem. + * @param expert_offsets 1D int32 tensor of size (num_experts), used to index into + * the grouped input tensors for dispatch. + */ +void es_fp8_blockwise_scaled_grouped_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_d, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets) { +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); + TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn"); + TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn"); + TORCH_CHECK( + output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf, + "output must be bfloat16 or half"); + + int num_experts = (int)problem_sizes.size(0); + torch::TensorOptions options_int64 = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + torch::TensorOptions options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(a.device()); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor a_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int64); + + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32); + + torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32); + torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32); + torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32); + expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + layout_sfa, + layout_sfb, + lm_problem_sizes, + mm_problem_sizes, + hm_problem_sizes, + output, + a, + b, + scales_a, + scales_b, + problem_sizes, + expert_offsets); + if (output.dtype() == torch::kBFloat16) { + expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_d, + layout_sfa, + layout_sfb, + lm_problem_sizes, + mm_problem_sizes, + hm_problem_sizes); + } else if (output.dtype() == torch::kFloat16) { + expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_d, + layout_sfa, + layout_sfb, + lm_problem_sizes, + mm_problem_sizes, + hm_problem_sizes); + } else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +#else + TORCH_CHECK_NOT_IMPLEMENTED( + can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version); +#endif +} diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh new file mode 100644 index 000000000..c4af33447 --- /dev/null +++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh @@ -0,0 +1,268 @@ +#pragma once +#include + +#include + +#include "cute/tensor.hpp" +#include "es_fp8_blockwise_traits.cuh" + +namespace expert_specialization { + +using namespace cute; + +template +struct Fp8BlockwiseGroupedGemmOffsetFunctor { + // Input + int* expert_offsets{nullptr}; + // Base pointers + ElementAB* a_base{nullptr}; + ElementAB* b_base{nullptr}; + ElementD* out_base{nullptr}; + ElementSF* a_scales_base{nullptr}; + ElementSF* b_scales_base{nullptr}; + + // Output + // Pointer Array for A/B + ElementAB** a_offsets{nullptr}; + ElementAB** b_offsets{nullptr}; + ElementSF** a_scales_offsets{nullptr}; + ElementSF** b_scales_offsets{nullptr}; + ElementD** out_offsets{nullptr}; + + Fp8BlockwiseGroupedGemmOffsetFunctor() = default; + Fp8BlockwiseGroupedGemmOffsetFunctor( + int* _expert_offsets, + ElementAB* _a_base, + ElementAB* _b_base, + ElementD* _out_base, + ElementSF* _a_scales_base, + ElementSF* _b_scales_base, + ElementAB** _a_offsets, + ElementAB** _b_offsets, + ElementSF** _a_scales_offsets, + ElementSF** _b_scales_offsets, + ElementD** _out_offsets) + : expert_offsets(_expert_offsets), + a_base(_a_base), + b_base(_b_base), + out_base(_out_base), + a_scales_base(_a_scales_base), + b_scales_base(_b_scales_base), + a_offsets(_a_offsets), + b_offsets(_b_offsets), + a_scales_offsets(_a_scales_offsets), + b_scales_offsets(_b_scales_offsets), + out_offsets(_out_offsets) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t a_stride = 0; + int64_t b_stride = 0; + int64_t a_scale_stride = 0; + int64_t b_scale_stride = 0; + + a_stride = expert_offset * k; + b_stride = expert_id * k * n; + a_scale_stride = expert_offset * k / 128; + b_scale_stride = expert_id * k * n / 128 / 128; + + a_offsets[expert_id] = a_base + a_stride; + b_offsets[expert_id] = b_base + b_stride; + a_scales_offsets[expert_id] = a_scales_base + a_scale_stride; + b_scales_offsets[expert_id] = b_scales_base + b_scale_stride; + out_offsets[expert_id] = out_base + expert_offset * n; + } +}; + +template +struct Fp8BlockwiseGroupedGemmSFLayoutFunctor { + using ScaleConfig = typename PerfConfig::ScaleConfig; + using LayoutSFA = typename PerfConfig::LayoutSFA; + using LayoutSFB = typename PerfConfig::LayoutSFB; + LayoutSFA* layout_sfa_base{nullptr}; + LayoutSFB* layout_sfb_base{nullptr}; + + Fp8BlockwiseGroupedGemmSFLayoutFunctor() = default; + Fp8BlockwiseGroupedGemmSFLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base) + : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id; + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + } +}; + +// [Unused]: Specialization for Swap A/B +template <> +struct Fp8BlockwiseGroupedGemmSFLayoutFunctor { + using ScaleConfig = typename PerfConfigLowMH20::ScaleConfig; + using LayoutSFA = typename PerfConfigLowMH20::LayoutSFA; + using LayoutSFB = typename PerfConfigLowMH20::LayoutSFB; + LayoutSFA* layout_sfa_base{nullptr}; + LayoutSFB* layout_sfb_base{nullptr}; + + Fp8BlockwiseGroupedGemmSFLayoutFunctor() = default; + Fp8BlockwiseGroupedGemmSFLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base) + : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id; + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(n, m, k, 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(n, m, k, 1)); + } +}; + +template +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor; + +template <> +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor { + int* problem_sizes{nullptr}; + + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default; + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + if (m <= 48) { + // Swap A/B + problem_sizes[expert_id * 3 + 0] = n; + problem_sizes[expert_id * 3 + 1] = m; + problem_sizes[expert_id * 3 + 2] = k; + } else { + problem_sizes[expert_id * 3 + 0] = 0; + problem_sizes[expert_id * 3 + 1] = 0; + problem_sizes[expert_id * 3 + 2] = 0; + } + } +}; + +template <> +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor { + int* problem_sizes{nullptr}; + + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default; + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + if (m <= 32) { + // Swap A/B + problem_sizes[expert_id * 3 + 0] = n; + problem_sizes[expert_id * 3 + 1] = m; + problem_sizes[expert_id * 3 + 2] = k; + } else { + problem_sizes[expert_id * 3 + 0] = 0; + problem_sizes[expert_id * 3 + 1] = 0; + problem_sizes[expert_id * 3 + 2] = 0; + } + } +}; + +template <> +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor { + int* problem_sizes{nullptr}; + + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default; + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + if (m > 48 && m <= 96) { + problem_sizes[expert_id * 3 + 0] = m; + problem_sizes[expert_id * 3 + 1] = n; + problem_sizes[expert_id * 3 + 2] = k; + } else { + problem_sizes[expert_id * 3 + 0] = 0; + problem_sizes[expert_id * 3 + 1] = 0; + problem_sizes[expert_id * 3 + 2] = 0; + } + } +}; + +template <> +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor { + int* problem_sizes{nullptr}; + + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default; + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + if (m > 32 && m <= 64) { + problem_sizes[expert_id * 3 + 0] = n; + problem_sizes[expert_id * 3 + 1] = m; + problem_sizes[expert_id * 3 + 2] = k; + } else { + problem_sizes[expert_id * 3 + 0] = 0; + problem_sizes[expert_id * 3 + 1] = 0; + problem_sizes[expert_id * 3 + 2] = 0; + } + } +}; + +template <> +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor { + int* problem_sizes{nullptr}; + + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default; + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + if (m > 96) { + problem_sizes[expert_id * 3 + 0] = m; + problem_sizes[expert_id * 3 + 1] = n; + problem_sizes[expert_id * 3 + 2] = k; + } else { + problem_sizes[expert_id * 3 + 0] = 0; + problem_sizes[expert_id * 3 + 1] = 0; + problem_sizes[expert_id * 3 + 2] = 0; + } + } +}; + +template <> +struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor { + int* problem_sizes{nullptr}; + + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor() = default; + Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor(int* _problem_sizes) : problem_sizes(_problem_sizes) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + if (m > 64) { + problem_sizes[expert_id * 3 + 0] = m; + problem_sizes[expert_id * 3 + 1] = n; + problem_sizes[expert_id * 3 + 2] = k; + } else { + problem_sizes[expert_id * 3 + 0] = 0; + problem_sizes[expert_id * 3 + 1] = 0; + problem_sizes[expert_id * 3 + 2] = 0; + } + } +}; + +template < + typename OffsetFunctor, + typename ScaleLayoutFunctor, + typename LowMProblemSizeFilterFunctor, + typename MiddleMProblemSizeFilterFunctor, + typename HighMProblemSizeFilterFunctor> +__global__ void groupedGemmPreComputeKernel( + int* problem_sizes, + OffsetFunctor offset_functor, + ScaleLayoutFunctor sf_functor, + LowMProblemSizeFilterFunctor lm_psf_functor, + MiddleMProblemSizeFilterFunctor mm_psf_functor, + HighMProblemSizeFilterFunctor hm_psf_functor) { + int64_t expert_id = static_cast(threadIdx.x); + int m = problem_sizes[expert_id * 3 + 0]; + int n = problem_sizes[expert_id * 3 + 1]; + int k = problem_sizes[expert_id * 3 + 2]; + + offset_functor(expert_id, m, n, k); + sf_functor(expert_id, m, n, k); + lm_psf_functor(expert_id, m, n, k); + mm_psf_functor(expert_id, m, n, k); + hm_psf_functor(expert_id, m, n, k); +} + +} // namespace expert_specialization diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh new file mode 100644 index 000000000..3b0f4a8a2 --- /dev/null +++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh @@ -0,0 +1,312 @@ +#pragma once +#include +#include +#include + +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "es_fp8_blockwise_functor.cuh" + +namespace expert_specialization { + +using namespace cute; + +void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute( + // Output + torch::Tensor& out_ptrs, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, + torch::Tensor& layout_sfa, + torch::Tensor& layout_sfb, + torch::Tensor& lm_problem_sizes, + torch::Tensor& mm_problem_sizes, + torch::Tensor& hm_problem_sizes, + // Input + torch::Tensor& out_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& expert_offsets) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + const std::string H20_device_type_str("NVIDIA H20"); + bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str; + + // Creat Scale Factor Layout Functor + using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA; + using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB; + struct Fp8BlockwiseGroupedGemmSFLayoutFunctor sf_layout( + reinterpret_cast(layout_sfa.data_ptr()), reinterpret_cast(layout_sfb.data_ptr())); + + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + // Dispatch + if (out_tensors.dtype() == torch::kBFloat16) { + struct Fp8BlockwiseGroupedGemmOffsetFunctor of( + static_cast(expert_offsets.data_ptr()), + static_cast(a_tensors.data_ptr()), + static_cast(b_tensors.data_ptr()), + static_cast(out_tensors.data_ptr()), + static_cast(a_scales.data_ptr()), + static_cast(b_scales.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + static_cast(out_ptrs.data_ptr())); + if (!is_h20_device) { + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf( + static_cast(lm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf( + static_cast(mm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf( + static_cast(hm_problem_sizes.data_ptr())); + groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>( + static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf); + } else { + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf( + static_cast(lm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf( + static_cast(mm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf( + static_cast(hm_problem_sizes.data_ptr())); + groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>( + static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf); + } + } else if (out_tensors.dtype() == torch::kFloat16) { + struct Fp8BlockwiseGroupedGemmOffsetFunctor of( + static_cast(expert_offsets.data_ptr()), + static_cast(a_tensors.data_ptr()), + static_cast(b_tensors.data_ptr()), + static_cast(out_tensors.data_ptr()), + static_cast(a_scales.data_ptr()), + static_cast(b_scales.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + static_cast(out_ptrs.data_ptr())); + if (!is_h20_device) { + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf( + static_cast(lm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf( + static_cast(mm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf( + static_cast(hm_problem_sizes.data_ptr())); + groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>( + static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf); + } else { + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf( + static_cast(lm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf( + static_cast(mm_problem_sizes.data_ptr())); + struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf( + static_cast(hm_problem_sizes.data_ptr())); + groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>( + static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf); + } + } else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +template +void launch_sm90_fp8_blockwise_scaled_group_mm( + torch::Tensor& out_ptrs, + const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, + const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_d, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes) { + using ElementA = typename GemmTraits::ElementA; + using StrideA = typename GemmTraits::StrideA; + using ElementB = typename GemmTraits::ElementB; + using StrideB = typename GemmTraits::StrideB; + using ElementAccumulator = typename GemmTraits::ElementAccumulator; + using LayoutSFA = typename GemmTraits::LayoutSFA; + using LayoutSFB = typename GemmTraits::LayoutSFB; + using ElementD = typename GemmTraits::ElementD; + using StrideD = typename GemmTraits::StrideD; + using UnderlyingProblemShape = typename GemmTraits::ProblemShape::UnderlyingProblemShape; + using Gemm = typename GemmTraits::Gemm; + using GemmKernel = typename GemmTraits::GemmKernel; + + int num_experts = (int)problem_sizes.size(0); + Gemm gemm_op; + + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(stride_a.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(stride_b.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = c10::cuda::current_device(); + hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, nullptr, nullptr, static_cast(out_ptrs.data_ptr()), static_cast(stride_d.data_ptr())}; + + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info}; + + at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); + + torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device()); + size_t workspace_size = gemm_op.get_workspace_size(args); + torch::Tensor workspace = torch::empty(workspace_size, options_uint8); + + auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(stream, nullptr, true); // Enable PDL + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( + torch::Tensor& out_ptrs, + const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, + const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_d, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + const torch::Tensor& lm_problem_sizes, + const torch::Tensor& mm_problem_sizes, + const torch::Tensor& hm_problem_sizes) { + using LowMGemmH20Traits = + ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits; + using LowMGemmHx00Traits = + ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits; + using MiddleMGemmH20Traits = + ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits; + using MiddleMGemmHx00Traits = ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits< + OutType, + cutlass::layout::ColumnMajor, + PerfConfigMiddleMHx00>; + using HighMGemmH20Traits = + ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits; + using HighMGemmHx00Traits = + ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits; + + const std::string H20_device_type_str("NVIDIA H20"); + bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str; + + if (!is_h20_device) { + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + b_ptrs, + a_ptrs, + b_scales_ptrs, + a_scales_ptrs, + stride_b, + stride_a, + stride_d, + layout_sfb, + layout_sfa, + lm_problem_sizes); + } else { + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + b_ptrs, + a_ptrs, + b_scales_ptrs, + a_scales_ptrs, + stride_b, + stride_a, + stride_d, + layout_sfb, + layout_sfa, + lm_problem_sizes); + } + + if (!is_h20_device) { + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + b_ptrs, + a_ptrs, + b_scales_ptrs, + a_scales_ptrs, + stride_b, + stride_a, + stride_d, + layout_sfb, + layout_sfa, + mm_problem_sizes); + } else { + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_d, + layout_sfa, + layout_sfb, + mm_problem_sizes); + } + + if (!is_h20_device) { + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_d, + layout_sfa, + layout_sfb, + hm_problem_sizes); + } else { + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_d, + layout_sfa, + layout_sfb, + hm_problem_sizes); + } +} + +} // namespace expert_specialization diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh new file mode 100644 index 000000000..3bc7d929a --- /dev/null +++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh @@ -0,0 +1,174 @@ +#pragma once + +// Misc +#include "cute/tensor.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/detail/blockwise_scale_layout.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_size.h" + +// Collective Builder +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +// Integration +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +namespace expert_specialization { + +using namespace cute; + +struct PerfConfigLowMH20 { + // Swap A/B + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_128, _32, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); +}; + +struct PerfConfigLowMHx00 { + // Swap A/B + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_256, _32, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); +}; + +struct PerfConfigMiddleMH20 { + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); +}; + +struct PerfConfigMiddleMHx00 { + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_256, _64, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); +}; + +struct PerfConfigHighMH20 { + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); +}; + +struct PerfConfigHighMHx00 { + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); +}; + +template +struct ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = OutType; + using ElementAccumulator = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = LayoutD; + using LayoutSFA = typename PerfConfig::LayoutSFA; + using LayoutSFB = typename PerfConfig::LayoutSFB; + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using CustomEVTIdentity = // acc + cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion:: + Sm90Compute, + cutlass::epilogue::fusion::Sm90AccFetch>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + typename PerfConfig::MmaTileShape, + typename PerfConfig::ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, // Use void to avoid load Matrix C + LayoutC*, + AlignmentC, + ElementD, + LayoutD*, + AlignmentD, + typename PerfConfig::EpilogueSchedule, + CustomEVTIdentity>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + typename PerfConfig::MmaTileShape, + typename PerfConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename PerfConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; +}; + +} // namespace expert_specialization diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 040084f3e..255c0a311 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -821,3 +821,18 @@ void causal_conv1d_fwd( const std::optional& has_initial_state, bool silu_activation, int64_t pad_slot_id); + +/* + * From csrc/expert_specialization + */ +void es_fp8_blockwise_scaled_grouped_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_d, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index c2c200269..6091ac41c 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -244,6 +244,7 @@ from sgl_kernel.elementwise import ( rmsnorm, silu_and_mul, ) +from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm from sgl_kernel.fused_moe import fused_marlin_moe from sgl_kernel.gemm import ( awq_dequantize, diff --git a/sgl-kernel/python/sgl_kernel/expert_specilization.py b/sgl-kernel/python/sgl_kernel/expert_specilization.py new file mode 100644 index 000000000..81b411cf5 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/expert_specilization.py @@ -0,0 +1,27 @@ +import torch + + +def es_fp8_blockwise_scaled_grouped_mm( + output, + a, + b, + scales_a, + scales_b, + stride_a, + stride_b, + stride_d, + problem_sizes, + expert_offsets, +): + torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default( + output, + a, + b, + scales_a, + scales_b, + stride_a, + stride_b, + stride_d, + problem_sizes, + expert_offsets, + ) diff --git a/sgl-kernel/tests/test_es_fp8_blockwise_moe.py b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py new file mode 100644 index 000000000..9118facaa --- /dev/null +++ b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py @@ -0,0 +1,204 @@ +import random +from typing import Tuple + +import pytest +import torch +from sgl_kernel import es_fp8_blockwise_scaled_grouped_mm + + +def cdiv(a: int, b: int) -> int: + return -(a // -b) + + +def scale_shape(shape, group_shape): + return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (128 - (n % 128)) % 128 + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +def baseline_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], +) -> torch.Tensor: + + def group_broadcast(t, shape): + for i, s in enumerate(shape): + if t.shape[i] != s and t.shape[i] != 1: + assert s % t.shape[i] == 0 + t = ( + t.unsqueeze(i + 1) + .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) + .flatten(i, i + 1) + ) + return t + + scale_a = group_broadcast(scale_a, a.shape) + scale_b = group_broadcast(scale_b, b.shape) + + return torch.mm( + (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) + ).to(out_dtype) + + +def is_sm100_supported(device=None) -> bool: + return (torch.cuda.get_device_capability(device)[0] == 10) and ( + torch.version.cuda >= "12.8" + ) + + +def is_sm90_supported(device=None) -> bool: + return (torch.cuda.get_device_capability(device)[0] == 9) and ( + torch.version.cuda >= "12.3" + ) + + +@pytest.mark.skipif( + not (is_sm100_supported() or is_sm90_supported()), + reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90", +) +@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128]) +@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16]) +def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype): + device = "cuda" + alignment = 128 + n_g = random.randint(1, 64) * alignment + k_g = random.randint(1, 64) * alignment + + expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32) + problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + baseline_tensors = [] + + for g in range(num_experts): + m_g = random.randint(1, 256) + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) + + a = torch.randn((m_g, k_g), device=device, dtype=out_dtype) # (M, K):(K, 1) + b = torch.randn((n_g, k_g), device=device, dtype=out_dtype).t() # (K, N):(1, K) + + a_g, a_scale = per_token_cast_to_fp8( + a + ) # ag -- (M, K):(K, 1), a_scale() -- (M, k):(k, 1) + b_g, b_scale = per_block_cast_to_fp8( + b + ) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1) + a_tensors.append(a_g) + b_tensors.append(b_g) + a_scales_tensors.append(a_scale) + b_scales_tensors.append(b_scale) + + baseline = torch.mm(a, b) + baseline_tensors.append(baseline) + a_stack = torch.empty( + (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn + ) + b_stack = torch.empty( + (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) + a_scale_stack = torch.empty( + (expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32 + ) + b_scale_stack = torch.empty( + (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32 + ) + + for g in range(num_experts): + # Matrix A is Row-Major. + a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[ + g + ] # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1) + b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1) + + # We need K-Major scale factor + a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[ + g + ] + b_scale_stack[g] = b_scales_tensors[ + g + ].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later + b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major. + b_scale_stack = b_scale_stack.transpose(1, 2) + + c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) + a_strides = torch.full( + (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64 + ) + d_strides = torch.full( + (num_experts,), c_out.stride(0), device=device, dtype=torch.int64 + ) + + es_fp8_blockwise_scaled_grouped_mm( + c_out, + a_stack, + b_stack, + a_scale_stack, + b_scale_stack, + a_strides, + a_strides, + d_strides, + problem_sizes, + expert_offsets[:-1], + ) + + for g in range(num_experts): + baseline = baseline_tensors[g] + actual = c_out[expert_offsets[g] : expert_offsets[g + 1]] + diff = calc_diff(actual, baseline) + assert diff < 0.001 + print( + f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK" + ) + + +if __name__ == "__main__": + pytest.main([__file__])