From e67276ecb30595b6564cc4a029131d166c0814e5 Mon Sep 17 00:00:00 2001 From: "tql.99" <33377527+TianQiLin666666@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:47:15 +0800 Subject: [PATCH] feat: support cutlass_moe_fp8 kernel for fusedmoe in sm90 (#8678) --- python/sglang/srt/layers/moe/cutlass_moe.py | 26 +++++++++++++++----- python/sglang/srt/layers/quantization/fp8.py | 6 ++--- python/sglang/srt/layers/utils.py | 9 +++++++ 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 3774afac2..6dadf0d0f 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import torch from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.utils import is_cuda _is_cuda = is_cuda() @@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8( if is_cuda: from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8_hopper_moe_mn_major, sglang_per_token_group_quant_fp8, ) @@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8( n = w2_q.size(1) topk = topk_ids.size(1) - - a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128) - device = a_q.device + device = a.device a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) @@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8( k, ) - rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k)) - rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128))) + if is_sm100_supported(): + a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128) + rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k)) + rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128))) + else: + rep_a = shuffle_rows(a, a_map, (m * topk, k)) + 
rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major( + rep_a, expert_offsets, problem_sizes1, 128 + ) + w1_scale = w1_scale.contiguous() c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) @@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8( intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) silu_and_mul(c1, intermediate) - intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128) + if is_sm100_supported(): + intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128) + else: + intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major( + intermediate, expert_offsets, problem_sizes2, 128 + ) + w2_scale = w2_scale.contiguous() fp8_blockwise_scaled_grouped_mm( c2, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 1b0824051..17e1b7868 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -63,7 +63,7 @@ from sglang.srt.layers.quantization.utils import ( per_tensor_dequantize, requantize_with_max_scale, ) -from sglang.srt.layers.utils import is_sm100_supported +from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -619,7 +619,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): if ( get_bool_env_var("SGLANG_CUTLASS_MOE") and self.cutlass_fp8_supported - and is_sm100_supported() + and (is_sm100_supported() or is_sm90_supported()) ): self.ab_strides1 = torch.full( (num_experts,), @@ -1034,7 +1034,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): get_bool_env_var("SGLANG_CUTLASS_MOE") and self.cutlass_fp8_supported and self.block_quant - and is_sm100_supported() + and (is_sm100_supported() or is_sm90_supported()) ): from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 diff --git 
a/python/sglang/srt/layers/utils.py b/python/sglang/srt/layers/utils.py
index f61b86293..ac0ddb65c 100644
--- a/python/sglang/srt/layers/utils.py
+++ b/python/sglang/srt/layers/utils.py
@@ -1,5 +1,6 @@
 import logging
 import re
+from functools import lru_cache
 
 import torch
 
@@ -35,7 +36,44 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input
 
 
+def _cuda_version_at_least(major: int, minor: int) -> bool:
+    """Return True if torch was built against CUDA >= ``major.minor``.
+
+    Version components are compared as integers because comparing the
+    raw strings is lexicographic and would order "12.10" before "12.8".
+    ``torch.version.cuda`` is ``None`` on builds without CUDA (CPU-only
+    or ROCm); that case, and an unparsable version, yield False.
+    """
+    cuda_version = torch.version.cuda
+    if cuda_version is None:
+        return False
+    parts = cuda_version.split(".")
+    try:
+        found = (int(parts[0]), int(parts[1]) if len(parts) > 1 else 0)
+    except ValueError:
+        return False
+    return found >= (major, minor)
+
+
+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
+    """Return True if ``device`` has capability major 10 and CUDA >= 12.8.
+
+    Only the result for the most recent ``device`` argument is cached.
+    """
+    return (
+        torch.cuda.get_device_capability(device)[0] == 10
+        and _cuda_version_at_least(12, 8)
+    )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    """Return True if ``device`` has capability major 9 and CUDA >= 12.3.
+
+    Only the result for the most recent ``device`` argument is cached.
+    """
+    return (
+        torch.cuda.get_device_capability(device)[0] == 9
+        and _cuda_version_at_least(12, 3)
+    )