From 825432fce673185e3349e3cdfd45e29e2a929222 Mon Sep 17 00:00:00 2001 From: Jinwu <70835312+ayrnb@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:10:53 +0800 Subject: [PATCH] [1/N]Support DeepSeek-R1 w4a8 normal deepep (#8247) Co-authored-by: Hank Han --- .../sglang/srt/layers/moe/cutlass_w4a8_moe.py | 196 ++++++++++++++++++ python/sglang/srt/layers/moe/ep_moe/layer.py | 23 +- .../srt/layers/moe/token_dispatcher/deepep.py | 14 +- python/sglang/srt/layers/moe/utils.py | 4 + .../sglang/srt/layers/quantization/w4afp8.py | 48 ++++- python/sglang/srt/server_args.py | 1 + test/srt/quant/test_w4a8_deepseek_v3.py | 55 +++++ 7 files changed, 334 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index e1507be18..2a84dedc4 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Cutlass W4A8 MoE kernel.""" +import logging from typing import Optional import torch @@ -11,6 +12,9 @@ from sgl_kernel import ( ) from sglang.srt.layers.moe.ep_moe.kernels import ( + deepep_permute_triton_kernel, + deepep_post_reorder_triton_kernel, + deepep_run_moe_deep_preprocess, post_reorder_triton_kernel_for_cutlass_moe, pre_reorder_triton_kernel_for_cutlass_moe, run_moe_ep_preproess, @@ -201,3 +205,195 @@ def cutlass_w4a8_moe( BLOCK_SIZE=512, ) return output + + +def cutlass_w4a8_moe_deepep_normal( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids_: torch.Tensor, + a_strides1: torch.Tensor, + b_strides1: torch.Tensor, + c_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides2: torch.Tensor, + c_strides2: torch.Tensor, + s_strides13: torch.Tensor, + s_strides2: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + 
a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a w4a8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of int4-quantized expert weights. + Shape: [num_experts, N * 2, K // 2] + (the weights are passed transposed and int4-packed) + - w2_q (torch.Tensor): The second set of int4-quantized expert weights. + Shape: [num_experts, K, N // 2] + (the weights are passed transposed and int4-packed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts, K // 512, N * 8] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts, N // 512, K * 4] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - a_strides1 (torch.Tensor): The input strides of the first grouped gemm. + - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - a_strides2 (torch.Tensor): The input strides of the second grouped gemm. + - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm. + - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [1, K] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. 
+ Shape: scalar or [1, N] + - topk_ids_ (torch.Tensor): The expert id of each token->expert mapping; + entries equal to -1 are remapped to num_experts before the grouped gemm. + + Returns: + - torch.Tensor: The bf16 output tensor after applying the MoE layer. + """ + assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" + assert w1_q.dtype == torch.int8 + assert w2_q.dtype == torch.int8 + assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1" + assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + + assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch" + assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" + assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" + assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(2) * 2 # w1_q is transposed and packed + n = w2_q.size(2) * 2 # w2_q is transposed and packed + topk = topk_ids_.size(1) + + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(2) * 2 + n = w2_q.size(2) * 2 + topk = topk_ids_.size(1) + device = a.device + + reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess( + topk_ids_, num_experts + ) + num_total_tokens = reorder_topk_ids.numel() + gateup_input_pre_reorder = torch.empty( + (int(num_total_tokens), a.shape[1]), + device=device, + dtype=a.dtype, + ) + deepep_permute_triton_kernel[(a.shape[0],)]( + a, + gateup_input_pre_reorder, + src2dst, + topk_ids_.to(torch.int64), + None, + topk, + a.shape[1], + BLOCK_SIZE=512, + ) + gateup_input = torch.empty( + gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device + ) + sgl_per_tensor_quant_fp8( + 
gateup_input_pre_reorder, gateup_input, a1_scale.float(), True + ) + del gateup_input_pre_reorder + local_topk_ids = topk_ids_ + local_topk_ids = ( + torch.where(local_topk_ids == -1, num_experts, topk_ids_).to(torch.int32) + ).contiguous() + + a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) + get_cutlass_w4a8_moe_mm_data( + local_topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16) + c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16) + + cutlass_w4a8_moe_mm( + c1, + gateup_input, + w1_q, + a1_scale.float(), + w1_scale, + expert_offsets[:-1], + problem_sizes1, + a_strides1, + b_strides1, + c_strides1, + s_strides13, + 128, + topk, + ) + intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16) + silu_and_mul(c1, intermediate) + + intermediate_q = torch.empty( + intermediate.shape, dtype=torch.float8_e4m3fn, device=device + ) + sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True) + + cutlass_w4a8_moe_mm( + c2, + intermediate_q, + w2_q, + a2_scale.float(), + w2_scale, + expert_offsets[:-1], + problem_sizes2, + a_strides2, + b_strides2, + c_strides2, + s_strides2, + 128, + topk, + ) + num_tokens = src2dst.shape[0] // topk + output = torch.empty( + (num_tokens, c2.shape[1]), + device=c2.device, + dtype=torch.bfloat16, + ) + deepep_post_reorder_triton_kernel[(num_tokens,)]( + c2, + output, + src2dst, + topk_ids_, + topk_weights, + topk, + c2.shape[1], + BLOCK_SIZE=512, + ) + + return output diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 3f68ad563..af82d54a4 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -29,6 +29,7 @@ from 
sglang.srt.layers.quantization.modelopt_quant import ( CUTEDSL_MOE_NVFP4_DISPATCH, ModelOptNvFp4FusedMoEMethod, ) +from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.single_batch_overlap import DownGemmOverlapArgs from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu @@ -96,6 +97,11 @@ class DeepEPMoE(FusedMoE): self.use_block_quant = getattr(self.quant_method, "block_quant", False) self.use_fp8_w8a8 = True self.fp8_dtype = torch.float8_e4m3fn + self.use_w4afp8 = False + elif isinstance(quant_config, W4AFp8Config): + self.use_w4afp8 = True + self.use_fp8_w8a8 = False + self.use_block_quant = False else: self.use_fp8_w8a8 = False self.use_block_quant = False @@ -142,7 +148,7 @@ class DeepEPMoE(FusedMoE): self.w13_weight, ( self.w13_weight_scale_inv - if self.use_block_quant + if self.use_block_quant or self.use_w4afp8 else self.w13_weight_scale ), ) @@ -150,7 +156,7 @@ class DeepEPMoE(FusedMoE): self.w2_weight, ( self.w2_weight_scale_inv - if self.use_block_quant + if self.use_block_quant or self.use_w4afp8 else self.w2_weight_scale ), ) @@ -210,6 +216,8 @@ class DeepEPMoE(FusedMoE): assert DispatchOutputChecker.format_is_deepep(dispatch_output) return self.forward_npu(dispatch_output) if DispatchOutputChecker.format_is_deepep_normal(dispatch_output): + if self.use_w4afp8: + return self.forward_cutlass_w4afp8(dispatch_output) assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 return self.forward_deepgemm_contiguous(dispatch_output) elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output): @@ -438,6 +446,17 @@ class DeepEPMoE(FusedMoE): ) return output + def forward_cutlass_w4afp8( + self, + dispatch_output: DeepEPNormalOutput, + ): + assert self.moe_runner_config.activation == "silu" + assert isinstance(self.quant_method, W4AFp8MoEMethod) + return self.quant_method.apply_deepep_normal( + 
layer=self, + dispatch_output=dispatch_output, + ) + def forward_deepgemm_masked( self, dispatch_output: DeepEPLLOutput, diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index abd0b9e82..618c4cf9e 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -14,7 +14,12 @@ from sglang.srt.layers.moe.token_dispatcher.base import ( DispatchOutput, DispatchOutputFormat, ) -from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled +from sglang.srt.layers.moe.utils import ( + DeepEPMode, + get_deepep_config, + get_moe_runner_backend, + is_tbo_enabled, +) from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.utils import ( get_bool_env_var, @@ -340,7 +345,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): topk_weights: torch.Tensor, ): topk_idx = topk_idx.to(torch.int64) - if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM: + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not get_moe_runner_backend().is_cutlass() + ): # TODO hard code 128 block quant,use fp8 communication hidden_states = sglang_per_token_group_quant_fp8( hidden_states, @@ -386,7 +394,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): async_finish=self.async_finish, allocate_on_comm_stream=previous_event is not None, ) - # FIXME: `handle` should be transmitted with tokens from dispatch to combine. # However, doing this would incur an unknown synchronization error, but keeping # `handle` as a member variable works. 
@@ -412,7 +419,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1, config=DeepEPConfig.get_instance().normal_dispatch_config, ) - get_global_expert_distribution_recorder().on_deepep_dispatch_normal( num_recv_tokens_per_expert, num_tokens_per_rank=num_tokens_per_rank, diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index f71192236..ac743dd86 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -55,6 +55,7 @@ class MoeRunnerBackend(Enum): FLASHINFER_CUTLASS = "flashinfer_cutlass" FLASHINFER_MXFP4 = "flashinfer_mxfp4" FLASHINFER_CUTEDSL = "flashinfer_cutedsl" + CUTLASS = "cutlass" def is_auto(self): return self == MoeRunnerBackend.AUTO @@ -80,6 +81,9 @@ class MoeRunnerBackend(Enum): def is_flashinfer_mxfp4(self): return self == MoeRunnerBackend.FLASHINFER_MXFP4 + def is_cutlass(self): + return self == MoeRunnerBackend.CUTLASS + class DeepEPMode(Enum): diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index fb85f0b31..e97de07d7 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch from torch.nn import Module @@ -21,8 +21,10 @@ from sglang.srt.utils import is_npu, set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.moe import MoeRunnerConfig + from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE from sglang.srt.layers.moe.token_dispatcher import ( CombineInput, + DeepEPNormalOutput, StandardDispatchOutput, ) @@ -326,3 +328,47 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): if self.moe_runner_config.routed_scaling_factor is not None: output *= 
self.moe_runner_config.routed_scaling_factor return StandardCombineInput(hidden_states=output) + + def apply_deepep_normal( + self, + layer: DeepEPMoE, + dispatch_output: DeepEPNormalOutput, + ) -> torch.Tensor: + from sglang.srt.layers.moe.cutlass_w4a8_moe import ( + cutlass_w4a8_moe_deepep_normal, + ) + + hidden_states, topk_idx, topk_weights = ( + dispatch_output.hidden_states, + dispatch_output.topk_idx, + dispatch_output.topk_weights, + ) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + num_tokens = hidden_states.shape[0] + if num_tokens > 0: + return cutlass_w4a8_moe_deepep_normal( + hidden_states, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale_inv, + layer.w2_weight_scale_inv, + topk_weights, + topk_idx, + self.a_strides1, + self.b_strides1, + self.c_strides1, + self.a_strides2, + self.b_strides2, + self.c_strides2, + self.s_strides13, + self.s_strides2, + self.expert_offsets, + self.problem_sizes1, + self.problem_sizes2, + layer.w13_input_scale, + layer.w2_input_scale, + ) + else: + return hidden_states diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8a43567ff..8d179b2c7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -137,6 +137,7 @@ MOE_RUNNER_BACKEND_CHOICES = [ "flashinfer_cutlass", "flashinfer_mxfp4", "flashinfer_cutedsl", + "cutlass", ] diff --git a/test/srt/quant/test_w4a8_deepseek_v3.py b/test/srt/quant/test_w4a8_deepseek_v3.py index eb813bd70..30e022796 100644 --- a/test/srt/quant/test_w4a8_deepseek_v3.py +++ b/test/srt/quant/test_w4a8_deepseek_v3.py @@ -118,5 +118,60 @@ class TestDeepseekV3W4Afp8Mtp(CustomTestCase): self.assertGreater(avg_spec_accept_length, 2.9) +class TestDeepseekV3W4Afp8DeepepNormal(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = try_cached_model(DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST) + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--tp", + "8", + "--trust-remote-code", + 
"--ep-size", + "8", + "--cuda-graph-bs", + "256", + "--disable-radix-cache", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "normal", + "--dp", + "8", + "--enable-dp-attention", + "--moe-runner-backend", + "cutlass", + ] + if not is_in_amd_ci(): + other_args += ["--mem-frac", "0.7"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k( + self, + ): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Eval accuracy of GSM8K: {metrics=}") + + self.assertGreater(metrics["accuracy"], 0.92) + + if __name__ == "__main__": unittest.main()