From 0b9dfba78700a331e097682d4505791450412cba Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Thu, 2 Oct 2025 18:02:19 +0800
Subject: [PATCH] Support low-latency dispatch (#10263)

Co-authored-by: Kaixi Hou <4001424+kaixih@users.noreply.github.com>
---
 python/sglang/srt/layers/moe/ep_moe/layer.py  | 11 ++++
 .../srt/layers/moe/flashinfer_cutedsl_moe.py  | 63 +++++++++++--------
 .../srt/layers/moe/token_dispatcher/deepep.py | 22 ++++++-
 .../srt/layers/quantization/modelopt_quant.py | 12 +++-
 python/sglang/srt/models/deepseek_v2.py       |  1 +
 5 files changed, 80 insertions(+), 29 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 0bd49600e..287bc00fc 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -31,6 +31,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
+from sglang.srt.layers.quantization.modelopt_quant import (
+    CUTEDSL_MOE_NVFP4_DISPATCH,
+    ModelOptNvFp4FusedMoEMethod,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
@@ -453,6 +457,13 @@ class DeepEPMoE(EPMoE):
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            input_global_scale=(
+                self.w13_input_scale_quant
+                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+                and self.quant_method.enable_flashinfer_cutedsl_moe
+                and CUTEDSL_MOE_NVFP4_DISPATCH
+                else None
+            ),
         )
 
     def moe_impl(self, dispatch_output: DispatchOutput):
diff --git a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
index c8813ff6f..f96361ecb 100644
--- a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
+++ b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
 
 
 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
@@ -36,7 +36,9 @@ def flashinfer_cutedsl_moe_masked(
     kernels.
 
     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following cases:
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,
@@ -48,13 +50,10 @@ def flashinfer_cutedsl_moe_masked(
         masked_m (torch.Tensor): Masked dimension indices
 
     Notes:
-        - Assumes max(masked_m) <= m.
+        - Assumes max(masked_m) == m.
""" # === Assertions on dtypes === - assert ( - input_global_scale.dtype == torch.float32 - ), f"input_global_scale must be float32, got {input_global_scale.dtype}" assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}" assert ( w1_blockscale.dtype == torch.float8_e4m3fn @@ -75,7 +74,31 @@ def flashinfer_cutedsl_moe_masked( # === Assertions on shapes === n = w2.shape[-1] * 2 # intermediate dimension - num_experts, m, k = hidden_states.shape + + if isinstance(hidden_states, tuple): + assert ( + input_global_scale is None + ), "input_global_scale is needed when input needs quant" + + a_q = hidden_states[0].view(torch.uint8) + a_q_sf = hidden_states[1].view(torch.float8_e4m3fn) + m, k_by_2, num_experts = a_q.shape + k = k_by_2 * 2 + else: + num_experts, m, k = hidden_states.shape + + assert ( + input_global_scale.dtype == torch.float32 + ), f"input_global_scale must be float32, got {input_global_scale.dtype}" + assert input_global_scale.shape == ( + num_experts, + ), f"input_global_scale must be (l,), got {input_global_scale.shape}" + + a_q, a_q_sf = scaled_fp4_grouped_quant( + hidden_states, + input_global_scale, + masked_m, + ) assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" assert ( @@ -85,10 +108,6 @@ def flashinfer_cutedsl_moe_masked( k, n // 2, ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}" - - assert input_global_scale.shape == ( - num_experts, - ), f"input_global_scale must be (l,), got {input_global_scale.shape}" assert w1_alpha.shape == ( num_experts, ), f"w1_alpha must be (l,), got {w1_alpha.shape}" @@ -99,27 +118,21 @@ def flashinfer_cutedsl_moe_masked( num_experts, ), f"w2_alpha must be (l,), got {w2_alpha.shape}" - aq, aq_sf = scaled_fp4_grouped_quant( - hidden_states, - input_global_scale, - masked_m, - ) + # TODO(kaixih@nvidia): dtype should be based on inputs. 
     gateup_output = torch.empty(
-        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+        (num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device
     )
     gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
     sf_vec_size = 16
-    assert aq_sf.dtype == torch.float8_e4m3fn
-    assert aq.dtype == torch.uint8
+    assert a_q_sf.dtype == torch.float8_e4m3fn
+    assert a_q.dtype == torch.uint8
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"
-
-    c_dtype = get_cute_dtype(hidden_states)
+    c_dtype = "bfloat16"
 
     # Gemm1
-
     grouped_gemm_nt_masked(
-        (aq, aq_sf),
+        (a_q, a_q_sf),
         (w1.permute(1, 2, 0), w1_blockscale),
         gateup_output,
         masked_m,
@@ -139,7 +152,7 @@
     )
 
     # Gemm2
-    out = torch.empty_like(hidden_states)
+    out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
     out = out.permute(1, 2, 0)  # requirement of kernel
     grouped_gemm_nt_masked(
         (diq, diq_sf),
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index 598f51331..da09f022b 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -296,6 +296,7 @@ class _DeepEPDispatcherImplBase:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -329,6 +330,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -505,6 +507,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -516,9 +519,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
+            input_global_scale,
             topk_idx,
-            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
-            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
@@ -558,9 +560,15 @@
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
-        use_fp8: bool = False,
     ):
+        use_nvfp4 = use_fp8 = False
+        if input_global_scale is not None:
+            use_nvfp4 = True
+        elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
+            use_fp8 = True
+
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
@@ -569,6 +577,12 @@
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
+                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+                **(
+                    dict(x_global_scale=input_global_scale)
+                    if input_global_scale is not None
+                    else dict()
+                ),
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
                 round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -682,13 +696,15 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
+            input_global_scale=input_global_scale,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index d72526a61..27a2ea950 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -80,6 +80,10 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
 USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
     "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
 )
+# TODO: make this default to true once the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
 
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
@@ -1234,6 +1238,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
             w13_input_scale = _slice_scale(w13_input_scale)
             w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
@@ -1476,7 +1484,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         out = flashinfer_cutedsl_moe_masked(
             hidden_states=x,
-            input_global_scale=layer.w13_input_scale_quant,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
             w1=layer.w13_weight,
             w1_blockscale=layer.w13_blockscale_swizzled,
             w1_alpha=layer.g1_alphas,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 74d207103..b486740c3 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -896,6 +896,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,
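
Note on the control flow: in this patch, input_global_scale doubles as the mode
switch. A non-None scale tells DeepEP's low-latency path to quantize tokens to
NVFP4 inside low_latency_dispatch, and flashinfer_cutedsl_moe_masked then
receives the pre-quantized (fp4, scale-factor) tuple with
input_global_scale=None; a None scale keeps the previous behavior (fp8 or bf16
dispatch, kernel-side quantization). A minimal sketch of that decision,
assuming only what the diffs above show; select_dispatch_mode and env_flag are
illustrative helpers, not sglang APIs:

import os

def env_flag(name: str) -> bool:
    # Rough stand-in for sglang's get_bool_env_var.
    return os.environ.get(name, "false").lower() in ("1", "true", "yes")

def select_dispatch_mode(input_global_scale) -> str:
    # Same decision _dispatch_core makes before buffer.low_latency_dispatch:
    # a scale means "quantize to NVFP4 during dispatch"; otherwise fp8,
    # unless SGLANG_DEEPEP_BF16_DISPATCH forces bf16.
    if input_global_scale is not None:
        return "nvfp4"  # forwards use_nvfp4=True, x_global_scale=input_global_scale
    if not env_flag("SGLANG_DEEPEP_BF16_DISPATCH"):
        return "fp8"    # forwards use_fp8=True
    return "bf16"

The two call sites stay complementary: with SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH
enabled, DeepEPMoE.dispatch hands w13_input_scale_quant to the dispatcher and
ModelOptNvFp4FusedMoEMethod passes input_global_scale=None to the kernel (the
tokens arrive pre-quantized); with it disabled, the dispatcher gets None and
the kernel quantizes internally via scaled_fp4_grouped_quant.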