From 2c4feaf30875728d278748221711e88e83ebc2ca Mon Sep 17 00:00:00 2001 From: ayrnb <70835312+ayrnb@users.noreply.github.com> Date: Thu, 3 Jul 2025 14:27:03 +0800 Subject: [PATCH] Add CUTLASS FP8 Blockscale MoE kernel for Hopper architecture (#7278) Co-authored-by: HydraQYH Co-authored-by: TianQiLin666666 <1834987979@qq.com> --- .../bench_fp8_blockwise_group_gemm.py | 330 ++++++++++++++++++ .../csrc/moe/fp8_blockwise_moe_kernel.cu | 245 ++++++++++++- sgl-kernel/tests/test_fp8_blockwise_moe.py | 12 +- 3 files changed, 578 insertions(+), 9 deletions(-) create mode 100644 sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py mode change 100755 => 100644 sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu diff --git a/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py b/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py new file mode 100644 index 000000000..2a0a8e410 --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py @@ -0,0 +1,330 @@ +import argparse +import random +from dataclasses import dataclass +from typing import List, Tuple + +import deep_gemm +import torch +from sgl_kernel import fp8_blockwise_scaled_grouped_mm + + +def get_m_alignment_for_contiguous_layout(): + return 128 + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (128 - (n % 128)) % 128 + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), 
dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +def construct_contiguous_grouped( + num_groups: int, expected_m_per_group: int, k: int, n: int +) -> Tuple[ + int, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + alignment = get_m_alignment_for_contiguous_layout() + group_ms = [int(expected_m_per_group) for _ in range(num_groups)] + m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) + + x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16) + m_indices = torch.empty(m, device="cuda", dtype=torch.int32) + out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) + + start = 0 + for i, group_m in enumerate(group_ms): + actual_end = start + group_m + aligned_end = start + ceil_div(group_m, alignment) * alignment + m_indices[start:actual_end] = i + m_indices[actual_end:aligned_end] = -1 + start = aligned_end + + assert m % 4 == 0, f"TMA alignment error: {m}" + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = ( + torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty( + (num_groups, ceil_div(n, 128), k // 128), device="cuda", dtype=torch.float + ), + ) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + return m, x_fp8, y_fp8, m_indices, out + + +def bench_deepgemm( + expected_m_per_group: int, + n: int, + k: int, + num_groups: int, + num_warmup: int, + num_run: int, +) -> Tuple[float, int]: + # construct tensors + m, x_fp8, y_fp8, m_indices, out = construct_contiguous_grouped( + num_groups, expected_m_per_group, k, n + ) + + 
def run_deepgemm(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + x_fp8, y_fp8, out, m_indices + ) + + # warmup + for _ in range(num_warmup): + run_deepgemm() + torch.cuda.synchronize() + + # run + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + latencies: list[float] = [] + start_event.record() + for _ in range(num_run): + run_deepgemm() + end_event.record() + end_event.synchronize() + torch.cuda.synchronize() + avg = start_event.elapsed_time(end_event) / num_run * 1000 # us + + return avg, m + + +def bench_cutlass( + expected_m_per_group: int, + n: int, + k: int, + num_groups: int, + num_warmup: int, + num_run: int, +) -> Tuple[float, int]: + device = "cuda" + alignment = 16 + n_g = ceil_div(n, alignment) * alignment + k_g = ceil_div(k, alignment) * alignment + out_dtype = torch.bfloat16 + + expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32) + problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32) + layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32) + layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32) + + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + + # TODO(@TianQiLin666666): Unique group_ms in all bench function + group_ms = [ + alignment * ceil_div(int(expected_m_per_group), alignment) + for _ in range(num_groups) + ] + for g in range(num_groups): + m_g = group_ms[g] + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) + + a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device)) + b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t()) + a_tensors.append(a_g) + b_tensors.append(b_g) + a_scales_tensors.append(a_scale) + b_scales_tensors.append(b_scale) + + a_stack = torch.empty( + (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn 
+ ) + b_stack = torch.empty( + (num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) + + for g in range(num_groups): + a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g] + b_stack[g] = b_tensors[g].t() + b_stack = b_stack.transpose(1, 2) + + a_scale_stack = torch.empty( + (expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32 + ) + b_scale_stack = torch.empty( + (num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32 + ) + + for g in range(num_groups): + a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g] + b_scale_stack[g] = b_scales_tensors[g].t() + b_scale_stack = b_scale_stack.transpose(1, 2) + + c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) + a_strides = torch.full( + (num_groups,), a_stack.stride(0), device=device, dtype=torch.int64 + ) + c_strides = torch.full( + (num_groups,), c_out.stride(0), device=device, dtype=torch.int64 + ) + workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8) + a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64) + + def run_cutlass(): + fp8_blockwise_scaled_grouped_mm( + c_out, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a_stack, + b_stack, + a_scale_stack, + b_scale_stack, + a_strides, + a_strides, + c_strides, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets[:-1], + workspace, + ) + + # warmup + for _ in range(num_warmup): + run_cutlass() + torch.cuda.synchronize() + + # run + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_run): + 
run_cutlass() + end_event.record() + end_event.synchronize() + torch.cuda.synchronize() + avg = start_event.elapsed_time(end_event) / num_run * 1000 # us + + return avg, expert_offsets[-1] + + +def bench_sglang_triton( + expected_m_per_group: int, + n: int, + k: int, + num_groups: int, + num_warmup: int, + num_run: int, +) -> Tuple[float, int]: + pass + + +benchmark_kernels = { + "deepgemm": bench_deepgemm, + "cutlass": bench_cutlass, + # "triton": bench_sglang_triton, +} + + +@dataclass +class ShapeArg: + expected_m_per_group: int + n: int + k: int + num_groups: int + + +def benchmark_one_shape( + shape_args: List[ShapeArg], + num_warmup: int, + num_run: int, +): + for shape in shape_args: + print( + f"\nBenchmark: expected_m_per_group={shape.expected_m_per_group}, n={shape.n}, k={shape.k}, num_groups={shape.num_groups}" + ) + for kernel_name, kernel_func in benchmark_kernels.items(): + average_time, m = kernel_func( + shape.expected_m_per_group, + shape.n, + shape.k, + shape.num_groups, + num_warmup, + num_run, + ) + print(f"{kernel_name}: {average_time} us") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num-warmup", type=int, default=3) + parser.add_argument("--num-run", type=int, default=10) + shape_args = [ + # Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8 + ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8 + ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16 + ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16 + ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256), + # Decode, DeepSeek-R1, gateup, bs = 32, TP = 8 + ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256), + # Decode, DeepSeek-R1, gateup, bs = 64, TP = 16 + ShapeArg(expected_m_per_group=2, n=256, k=7168, 
num_groups=256), + # Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8 + ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32), + # Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16 + ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16), + # Decode, DeepSeek-R1, gateup, bs = 128, EP = 8 + ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32), + # Decode, DeepSeek-R1, gateup, bs = 256, EP = 16 + ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16), + # Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4 + ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128), + # Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4 + ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128), + # Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4 + ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128), + # Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4 + ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128), + ] + args = parser.parse_args() + benchmark_one_shape(shape_args, args.num_warmup, args.num_run) + + +if __name__ == "__main__": + main() diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu old mode 100755 new mode 100644 index e3e170e47..dc022bcc9 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -30,6 +30,126 @@ using namespace cute; using ProblemShape = cutlass::gemm::GroupProblemShape>; + +template +void launch_sm90_fp8_blockwise_scaled_group_mm( + torch::Tensor& out_ptrs, + const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, + const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, + const 
torch::Tensor& expert_offsets, + const torch::Tensor& workspace) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = OutType; + using ElementAccumulator = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = LayoutD; + + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using FusionOperation = cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, // Use void to avoid load Matrix C + LayoutC*, + AlignmentC, + ElementD, + LayoutC*, + AlignmentC, + typename ScheduleConfig::EpilogueSchedule, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename 
Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + int num_experts = (int)expert_offsets.size(0); + Gemm gemm_op; + + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(stride_a.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(stride_b.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = c10::cuda::current_device(); + hw_info.sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, + nullptr, + static_cast(stride_c.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(stride_c.data_ptr())}; + + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info}; + + at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); + + auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + template void launch_sm100_fp8_blockwise_scaled_group_mm( torch::Tensor& out_ptrs, @@ -312,6 +432,74 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( } } +template +void 
sm90_fp8_blockwise_group_mm_dispatch_shape( + torch::Tensor& output, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace) { + struct MmaConfig { + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; + + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + }; + + int num_experts = (int)expert_offsets.size(0); + torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int); + + run_get_group_gemm_starts( + expert_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + output, + scales_a, + scales_b, + layout_sfa, + layout_sfb, + problem_sizes, + problem_sizes_transpose); + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets, + workspace); +} + /** * @brief Performs blockwise grouped matrix multiplication on FP8 quantized inputs, * with per-block scaling. 
@@ -397,11 +585,6 @@ void fp8_blockwise_scaled_grouped_mm( TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor"); TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor"); TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor"); - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK( - problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor"); @@ -455,7 +638,57 @@ void fp8_blockwise_scaled_grouped_mm( can_implement = true; } #endif +#endif + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + if (sm_version == 90 && a.size(1) > 256) { + if (output.scalar_type() == torch::kBFloat16) { + sm90_fp8_blockwise_group_mm_dispatch_shape( + output, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + scales_a, + scales_b, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets, + workspace); + } else { + sm90_fp8_blockwise_group_mm_dispatch_shape( + output, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + scales_a, + scales_b, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets, + workspace); + } + can_implement = true; + } #endif TORCH_CHECK_NOT_IMPLEMENTED( - can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); + can_implement, + "No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ", + sm_version, + a.size(1)); } diff --git 
a/sgl-kernel/tests/test_fp8_blockwise_moe.py b/sgl-kernel/tests/test_fp8_blockwise_moe.py
index 40415a582..a0249f496 100755
--- a/sgl-kernel/tests/test_fp8_blockwise_moe.py
+++ b/sgl-kernel/tests/test_fp8_blockwise_moe.py
@@ -53,9 +53,15 @@ def is_sm100_supported(device=None) -> bool:
     )
 
 
+def is_sm90_supported(device=None) -> bool:
+    # Compare CUDA versions numerically; as strings "12.10" >= "12.8" is False.
+    cuda_version = tuple(map(int, (torch.version.cuda or "0.0").split(".")[:2]))
+    return torch.cuda.get_device_capability(device)[0] == 9 and cuda_version >= (12, 8)
+
+
 @pytest.mark.skipif(
-    not is_sm100_supported(),
-    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
+    not (is_sm100_supported() or is_sm90_supported()),
+    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
 )
 @pytest.mark.parametrize("num_experts", [8, 16])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
@@ -162,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     for g in range(num_experts):
         baseline = baseline_tensors[g]
         actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
-        torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=5e-4)
+        torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=1e-3)
 
     print(f"num_experts={num_experts}, out_dtype={out_dtype}: OK")
