From d4a938417d2c310eae3bae19f1376a3fec142e07 Mon Sep 17 00:00:00 2001 From: chenxj Date: Tue, 2 Sep 2025 13:17:26 +0800 Subject: [PATCH] [feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118) Co-authored-by: yuhyao <827623970@qq.com> --- python/sglang/srt/configs/model_config.py | 3 +- .../sglang/srt/layers/moe/cutlass_w4a8_moe.py | 10 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 3 - .../srt/layers/moe/fused_moe_triton/layer.py | 7 +- .../sglang/srt/layers/quantization/w4afp8.py | 55 ++-- python/sglang/srt/models/deepseek_v2.py | 5 + python/sglang/test/test_cutlass_w4a8_moe.py | 33 ++- .../w4a8/w4a8_get_group_starts.cuh | 2 +- .../cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu | 266 ++++++++++++++---- .../cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh | 13 +- sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py | 14 +- 11 files changed, 291 insertions(+), 120 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 8fb00972e..caf1f2abc 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -405,9 +405,10 @@ class ModelConfig: # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) if quant_cfg is None: - # check if is modelopt model -- modelopt doesn't have corresponding field + # check if this is a modelopt or mixed-precision model -- each lacks a corresponding field # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main + # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main is_local = os.path.exists(self.model_path) modelopt_quant_config = {"quant_method": "modelopt"} if not is_local: diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 7a03511c4..8e4143e0e 100644 --- 
a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -91,18 +91,10 @@ def cutlass_w4a8_moe( assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert ( - w1_scale.shape[1] == w1_q.shape[2] * 2 / 512 - and w1_scale.shape[2] == w1_q.shape[1] * 4 - ), "W1 scale shape mismatch" - assert ( - w2_scale.shape[1] == w2_q.shape[2] * 2 / 512 - and w2_scale.shape[2] == w2_q.shape[1] * 4 - ), "W2 scale shape mismatch" assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch" assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" - assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" + assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" num_experts = w1_q.size(0) m = a.size(0) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 175914560..a4c78c589 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -114,9 +114,6 @@ class EPMoE(FusedMoE): with_bias=with_bias, ) - self.start_expert_id = self.moe_ep_rank * self.num_local_experts - self.end_expert_id = self.start_expert_id + self.num_local_experts - 1 - self.intermediate_size = intermediate_size if isinstance(quant_config, Fp8Config): diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 7b3452525..b88c60d96 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module): self.moe_tp_rank = 
get_moe_tensor_parallel_rank() assert num_experts % self.moe_ep_size == 0 self.num_local_experts = num_experts // self.moe_ep_size + self.start_expert_id = self.moe_ep_rank * self.num_local_experts + self.end_expert_id = self.start_expert_id + self.num_local_experts - 1 if self.moe_ep_size > 1: # TODO(ch-wan): support shared experts fusion # Create a tensor of size num_experts filled with -1 @@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module): if ( "compressed" in self.quant_method.__class__.__name__.lower() - and param.data[expert_id] != 1 - and (param.data[expert_id] - loaded_weight).abs() > 1e-5 + or "w4afp8" in self.quant_config.get_name() + and (param.data[expert_id] != 1).any() + and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any() ): raise ValueError( "input_scales of w1 and w3 of a layer " diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 9be54d05a..a1cdc6cba 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -1,12 +1,14 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch from torch.nn import Module from torch.nn.parameter import Parameter +from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, @@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig): from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.managers.schedule_batch import global_server_args_dict if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): return 
UnquantizedLinearMethod() return Fp8LinearMethod(self) - elif isinstance(layer, EPMoE): + elif isinstance(layer, FusedMoE): return W4AFp8MoEMethod(self) return None @@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig): return [] -class W4AFp8MoEMethod(FusedMoEMethodBase): +def interleave_scales(scales: torch.Tensor) -> torch.Tensor: + """Interleave scales in groups of 4 similar to TRT-LLM implementation.""" + s_shape = scales.shape + # Reshape to separate groups of 4 + alignment = 4 if s_shape[2] % 4 == 0 else 1 + scales_interleaved = scales.reshape( + s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment + ) + # Permute dimensions to interleave + scales_interleaved = scales_interleaved.permute(0, 2, 1, 3) + # Reshape back to original dimensions but with interleaved values + scales_interleaved = scales_interleaved.reshape( + s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment + ) + return scales_interleaved.contiguous() + +class W4AFp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: W4AFp8Config): self.quant_config = quant_config @@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): return - def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor: - """Interleave scales in groups of 4 similar to TRT-LLM implementation.""" - s_shape = scales.shape - # Reshape to separate groups of 4 - scales_interleaved = scales.reshape( - s_shape[0], s_shape[1], (s_shape[2] // 4), 4 - ) - # Permute dimensions to interleave - scales_interleaved = scales_interleaved.permute(0, 2, 1, 3) - # Reshape back to original dimensions but with interleaved values - scales_interleaved = scales_interleaved.reshape( - s_shape[0], s_shape[2] // 4, s_shape[1] * 4 - ) - return scales_interleaved.contiguous() - def process_weights_after_loading(self, layer: Module) -> None: dtype = torch.bfloat16 device = layer.w2_weight.device # Interleave w13_weight_scale (gate_up_proj) w13_weight_scale = layer.w13_weight_scale_inv.to(dtype) - 
w13_weight_scale = self._interleave_scales(w13_weight_scale) + w13_weight_scale = interleave_scales(w13_weight_scale) layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False) # Interleave w2_weight_scale (down_proj) w2_weight_scale = layer.w2_weight_scale_inv.to(dtype) - w2_weight_scale = self._interleave_scales(w2_weight_scale) + w2_weight_scale = interleave_scales(w2_weight_scale) layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False) # Process input scales @@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, _ = topk_output local_topk_ids = topk_ids - local_topk_ids = torch.where( - topk_ids == -1, - layer.num_experts, - topk_ids, - ) + if get_moe_expert_parallel_world_size() > 1: + local_topk_ids = torch.where( + topk_ids == -1, + layer.num_experts, + topk_ids, + ) output = cutlass_w4a8_moe( layer.start_expert_id, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 6058488a1..bceb60cfe 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2185,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module): disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization." elif get_moe_expert_parallel_world_size() > 1: disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism." + elif self.quant_config.get_name() == "w4afp8": + disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts." 
if disable_reason is not None: global_server_args_dict["disable_shared_experts_fusion"] = True @@ -2496,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module): ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, ) + # Params for special naming rules in mixed-precision models, for example: + # model.layers.xx.mlp.experts.xx.w1.input_scale. For details, + # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. if self.quant_config and self.quant_config.get_name() == "w4afp8": expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( num_experts=self.config.n_routed_experts diff --git a/python/sglang/test/test_cutlass_w4a8_moe.py b/python/sglang/test/test_cutlass_w4a8_moe.py index 622941f00..6706fc962 100644 --- a/python/sglang/test/test_cutlass_w4a8_moe.py +++ b/python/sglang/test/test_cutlass_w4a8_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Literal, Optional import pytest import torch @@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Ten return packed_tensor.to(torch.int8) -def pack_interleave(num_experts, ref_weight, ref_scale): +def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4): n, k = ref_weight.shape[1], ref_weight.shape[2] weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda() @@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale): w_q = w_q.contiguous() scale_interleaved = ref_scale.reshape( - ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4 + ref_scale.shape[0], + ref_scale.shape[1], + (ref_scale.shape[2] // alignment), + alignment, ) # [E, N, K/4, 4] scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4] scale_interleaved = scale_interleaved.reshape( - ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4 + ref_scale.shape[0], + ref_scale.shape[2] // alignment, + 
ref_scale.shape[1] * alignment, ) # [E, K/4, N*4] w_scale = scale_interleaved.contiguous() @@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale): @pytest.mark.parametrize("N", [2048]) @pytest.mark.parametrize("K", [7168]) @pytest.mark.parametrize("E", [256]) -@pytest.mark.parametrize("ep_size", [8]) +@pytest.mark.parametrize("tp_size", [8]) +@pytest.mark.parametrize("use_ep_moe", [True, False]) @pytest.mark.parametrize("topk", [8]) @pytest.mark.parametrize("group_size", [128]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype): - local_e = E // ep_size +def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype): + if use_ep_moe: + local_e = E // tp_size + else: # tp mode + local_e = E + N = N // tp_size debug = False if debug: @@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype): ) w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1) - w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2) + if use_ep_moe: + w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2) + else: + w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1) device = "cuda" a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64) @@ -265,7 +278,9 @@ def ref( gate, fc1 = fc1.chunk(2, dim=-1) fc1 = fc1 * torch.nn.functional.silu(gate) - act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn) + act = torch.clamp((fc1 / pre_quant_scale_2.float()), -448.0, 448.0).to( + torch.float8_e4m3fn + ) act = act.to(dtype) w2 = ref_weight_2[e_idx] diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh index f926202c0..8cd50c60c 100644 --- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh @@ -31,7 +31,7 @@ __global__ void 
int4_fp8_get_group_gemm_starts( b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2; out_offsets[expert_id] = out_base_as_int + expert_offset * n; a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0); - b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id); + b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id); } #define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu index cffa171cc..bd63d2ee1 100644 --- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu @@ -2,6 +2,8 @@ #include #include +#include + #include "cutlass/cutlass.h" #include "w4a8_grouped_mm_c3x.cuh" @@ -9,38 +11,60 @@ using namespace cute; namespace { -#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c +enum class Sched { PP, CO }; -#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c +template +struct SM90W4A8Config { + using KernelSchedule = std::conditional_t< + S == Sched::PP, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; -#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \ - struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \ - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \ - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \ - using TileShape = cute::Shape, cute::Int, cute::Int>; \ - using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ - \ - using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm; \ - }; + using EpilogueSchedule = std::conditional_t< + S == Sched::PP, + 
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong, + cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>; -#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \ - struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \ - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \ - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ - using TileShape = cute::Shape, cute::Int, cute::Int>; \ - using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ - \ - using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm; \ - }; + using TileShape = cute::Shape, cute::Int, cute::Int>; + using ClusterShape = cute::Shape, cute::Int, cute::Int>; + using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm; +}; -GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1) -GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1) +template +using SM90_PP = SM90W4A8Config; -GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1) -GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1) -GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1) -GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1) -GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1) +template +using SM90_CO = SM90W4A8Config; + +template +inline void invoke_gemm( + torch::Tensor& d_tensors, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& d_strides, + torch::Tensor const& s_strides, + int64_t chunk_size) { + using GemmT = typename Config::Cutlass3xW4A8Gemm; + cutlass_w4a8_group_gemm_caller( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); +} void dispatch_w4a8_moe_mm_sm90( torch::Tensor& d_tensors, @@ -56,9 +80,6 @@ void 
dispatch_w4a8_moe_mm_sm90( torch::Tensor const& s_strides, int64_t chunk_size, int64_t topk) { - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; - uint32_t const m = a_tensors.size(0) / topk; uint32_t const n = d_tensors.size(1); uint32_t const k = a_tensors.size(1); @@ -66,8 +87,7 @@ void dispatch_w4a8_moe_mm_sm90( if (n == 4096 && k == 7168) { // group gemm 1 if (m <= 4) { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -81,8 +101,7 @@ void dispatch_w4a8_moe_mm_sm90( s_strides, chunk_size); } else if (m <= 16) { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -96,8 +115,7 @@ void dispatch_w4a8_moe_mm_sm90( s_strides, chunk_size); } else if (m <= 256) { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -111,8 +129,7 @@ void dispatch_w4a8_moe_mm_sm90( s_strides, chunk_size); } else if (m <= 1024) { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -126,8 +143,7 @@ void dispatch_w4a8_moe_mm_sm90( s_strides, chunk_size); } else { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -144,8 +160,7 @@ void dispatch_w4a8_moe_mm_sm90( } else if (n == 7168 && k == 2048) { // group gemm 2 if (m <= 8) { - using Cutlass3xW4A8GemmSelected = typename 
JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -159,8 +174,7 @@ void dispatch_w4a8_moe_mm_sm90( s_strides, chunk_size); } else if (m <= 512) { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -174,8 +188,125 @@ void dispatch_w4a8_moe_mm_sm90( s_strides, chunk_size); } else { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } + } else if (n == 512 && k == 7168) { + // group gemm 1 for tp + if (m <= 4) { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 16) { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 256) { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 1024) { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } + } else if (n == 7168 && k == 256) { + // group gemm 2 for tp + if (m <= 8) { 
+ invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else if (m <= 512) { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else { + invoke_gemm>( d_tensors, a_tensors, b_tensors, @@ -190,20 +321,35 @@ void dispatch_w4a8_moe_mm_sm90( chunk_size); } } else { - using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm; - cutlass_w4a8_group_gemm_caller( - d_tensors, - a_tensors, - b_tensors, - a_scales, - b_scales, - expert_offsets, - problem_sizes, - a_strides, - b_strides, - d_strides, - s_strides, - chunk_size); + if (k % 512 == 0) { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } else { + invoke_gemm>( + d_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + d_strides, + s_strides, + chunk_size); + } } } diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh index 1252b245f..9bc45ab1c 100644 --- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh +++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh @@ -41,9 +41,8 @@ using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type using QuantType = cutlass::int4b_t; // 4-bit integer type using ElementAccumulator = float; // Accumulator type using ElementScale = cutlass::bfloat16_t; // Scale type -using ElementScalePacked = cutlass::Array; -using ElementC = cutlass::half_t; // Default output type (FP16) -using ElementD = ElementC; // Default output type (FP16) +using ElementC = cutlass::half_t; // Default 
output type (FP16) +using ElementD = ElementC; // Default output type (FP16) using ProblemShape = cutlass::gemm::GroupProblemShape>; // Architecture-specific configurations @@ -73,6 +72,10 @@ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; template struct cutlass_3x_w4a8_group_gemm { + static constexpr int GroupSize = 128; + static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize; + using ElementScalePacked = cutlass::Array; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -184,8 +187,6 @@ void cutlass_w4a8_group_gemm_caller( TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups"); TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups"); TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension"); - TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512"); - TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N"); // Check tensor types TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type"); @@ -241,7 +242,7 @@ void cutlass_w4a8_group_gemm_caller( static_cast(b_strides.data_ptr()), static_cast(a_ptrs.data_ptr()), static_cast(a_strides.data_ptr()), - static_cast(b_scales_ptrs.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), static_cast(s_strides.data_ptr()), static_cast(chunk_size)}, {fusion_args, diff --git a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py index 506f8301a..4ad5d29f5 100644 --- a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py +++ b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py @@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale): w_q = weight.view((num_experts, n, k // 2)).view(torch.int8) w_q = 
w_q.contiguous() + alignment = 4 if k % 512 == 0 else 1 scale_interleaved = ref_scale.reshape( - ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4 + ref_scale.shape[0], + ref_scale.shape[1], + (ref_scale.shape[2] // alignment), + alignment, ) # [E, N, K/4, 4] scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4] scale_interleaved = scale_interleaved.reshape( - ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4 + ref_scale.shape[0], + ref_scale.shape[2] // alignment, + ref_scale.shape[1] * alignment, ) # [E, K/4, N*4] w_scale = scale_interleaved.contiguous() @@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size): reason="cutlass_w4a8_moe_mm is only supported on sm90", ) @pytest.mark.parametrize("batch_size", [2, 4, 8, 16]) -@pytest.mark.parametrize("k", [512, 1024]) -@pytest.mark.parametrize("n", [1024, 2048]) +@pytest.mark.parametrize("k", [256, 512, 1024]) +@pytest.mark.parametrize("n", [1024, 2048, 7168]) @pytest.mark.parametrize("num_experts", [2, 4, 6, 8]) def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): torch.manual_seed(0)