From eff4eb3fdd81a82ac76ae16250d06c48fd7ab03e Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 15 Aug 2025 22:08:11 -0700 Subject: [PATCH] Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667) --- .../device_communicators/pynccl.py | 86 +++++++++++++++---- .../device_communicators/pynccl_wrapper.py | 52 +++++++++++ .../sglang/srt/distributed/parallel_state.py | 81 +++++++++++++++++ python/sglang/srt/layers/communicator.py | 11 ++- python/sglang/srt/layers/dp_attention.py | 25 +++++- python/sglang/srt/layers/logits_processor.py | 6 +- python/sglang/srt/layers/moe/__init__.py | 2 + .../srt/layers/moe/fused_moe_triton/layer.py | 5 +- python/sglang/srt/layers/moe/topk.py | 7 ++ python/sglang/srt/layers/moe/utils.py | 23 +++++ .../srt/layers/quantization/modelopt_quant.py | 47 ++++++++-- python/sglang/srt/managers/schedule_batch.py | 1 + .../srt/model_executor/forward_batch_info.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 51 +++++++---- python/sglang/srt/operations.py | 7 +- python/sglang/srt/server_args.py | 6 ++ 16 files changed, 360 insertions(+), 52 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 81dd81780..fbb59c477 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -148,7 +148,11 @@ class PyNcclCommunicator: ) def all_gather( - self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None, + sizes: Optional[list[int]] = None, ): if self.disabled: return @@ -161,14 +165,33 @@ class PyNcclCommunicator: ) if stream is None: stream = self.stream - self.nccl.ncclAllGather( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - self.comm, - cudaStream_t(stream.cuda_stream), - ) + + if sizes is not None: + split_offset = 0 + + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + dst_slice = output_tensor[split_offset : split_offset + split_size] + self.nccl.ncclBroadcast( + buffer_type(input_tensor.data_ptr()), + buffer_type(dst_slice.data_ptr()), + dst_slice.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() + else: + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def reduce_scatter( self, @@ -176,6 +199,7 @@ class PyNcclCommunicator: input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None, + sizes: Optional[list[int]] = None, ): if self.disabled: return @@ -188,15 +212,35 @@ class PyNcclCommunicator: ) if stream is None: stream = self.stream - self.nccl.ncclReduceScatter( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - output_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), - self.comm, - cudaStream_t(stream.cuda_stream), - ) + + if sizes is not None: + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + chunk = input_tensor[split_offset : split_offset + split_size, ...] 
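+                # ncclReduce sums every rank's copy of this chunk into `root`'s
+                # output buffer; batched inside ncclGroupStart/End, the per-root
+                # reduces implement an uneven reduce-scatter (reduce-scatter-v).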
+ + self.nccl.ncclReduce( + buffer_type(chunk.data_ptr()), + buffer_type(output_tensor.data_ptr()), + chunk.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() + else: + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: @@ -266,6 +310,12 @@ class PyNcclCommunicator: def deregister_comm_window(self, window): return self.nccl.ncclCommWindowDeregister(self.comm, window) + def group_start(self): + self.nccl.ncclGroupStart() + + def group_end(self): + self.nccl.ncclGroupEnd() + @contextmanager def change_state( self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index cad39624e..579811777 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -206,6 +206,26 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, int root, + # ncclComm_t comm, cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduceScatter( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, @@ -278,6 +298,10 @@ class NCCLLibrary: # it is better not to call it at all. 
# ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclResult_t ncclGroupStart(); + Function("ncclGroupStart", ncclResult_t, []), + # ncclResult_t ncclGroupEnd(); + Function("ncclGroupEnd", ncclResult_t, []), ] exported_functions_symm_mem = [ @@ -400,6 +424,28 @@ class NCCLLibrary: ) ) + def ncclReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduce"]( + sendbuff, recvbuff, count, datatype, op, root, comm, stream + ) + ) + def ncclReduceScatter( self, sendbuff: buffer_type, @@ -499,6 +545,12 @@ class NCCLLibrary: def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + def ncclGroupStart(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupStart"]()) + + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + __all__ = [ "NCCLLibrary", diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index adb43158f..286618d6b 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -583,6 +583,39 @@ class GroupCoordinator: torch.distributed.reduce_scatter(output, input_list, group=self.device_group) return output + def reduce_scatterv( + self, + input_: torch.Tensor, + output: Optional[torch.Tensor] = None, + sizes: Optional[List[int]] = None, + ) -> torch.Tensor: + world_size = self.world_size + pynccl_comm = self.pynccl_comm + + with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()): + assert ( + pynccl_comm is not None and not pynccl_comm.disabled + ), "pynccl is required for reduce_scatterv" + + if sizes is not None: + assert len(sizes) == world_size + assert input_.shape[0] == sum(sizes) + chunk_size = sizes[self.rank_in_group] + else: + assert input_.shape[0] % world_size == 0 + chunk_size = input_.shape[0] // world_size + output_shape = (chunk_size,) + input_.shape[1:] + + if output is None: + output = torch.empty( + output_shape, dtype=input_.dtype, device=input_.device + ) + else: + assert output.shape == output_shape + + pynccl_comm.reduce_scatter(output, input_, sizes=sizes) + return output + def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: @@ -673,6 +706,54 @@ class GroupCoordinator: ) return output_tensor + def all_gatherv( + self, + input_: Union[torch.Tensor, List[torch.Tensor]], + sizes: Optional[List[int]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Supports varying sizes per rank and input tensor list. + `sizes`: a list of len(world_size) with the number of items per rank to gather. 
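+        Example: with world_size=2 and sizes=[3, 5], rank 0 passes a (3, H)
+        tensor and rank 1 a (5, H) tensor; every rank receives the concatenated
+        (8, H) result (one such output per input tensor when a list is passed).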
+ """ + world_size = self.world_size + pynccl_comm = self.pynccl_comm + + with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()): + assert ( + pynccl_comm is not None and not pynccl_comm.disabled + ), "pynccl is required for all_gatherv" + + def _all_gather_single( + input_: torch.Tensor, sizes: Optional[List[int]] = None + ): + input_size = input_.size() + if sizes is not None: + assert len(sizes) == world_size + assert input_.shape[0] == sizes[self.rank_in_group] + output_size = (sum(sizes),) + input_size[1:] + # 'sizes' is not needed if all inputs in the same group have the same shape + if all(s == sizes[0] for s in sizes): + sizes = None + else: + output_size = (input_size[0] * world_size,) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) + return output_tensor + + if isinstance(input_, torch.Tensor): + return _all_gather_single(input_, sizes) + + output_list = [] + pynccl_comm.group_start() + for inp in input_: + output_list.append(_all_gather_single(inp, sizes=sizes)) + pynccl_comm.group_end() + + return output_list + def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 ) -> Optional[torch.Tensor]: diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 27a1721aa..73a9030f7 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import ( get_global_dp_buffer, get_local_dp_buffer, ) -from sglang.srt.layers.moe import get_moe_a2a_backend +from sglang.srt.layers.moe import ( + get_moe_a2a_backend, + should_use_flashinfer_cutlass_moe_fp4_allgather, +) from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -112,7 +115,11 @@ class LayerScatterModes: if context.is_layer_sparse: return ( ScatterMode.SCATTERED - if not get_moe_a2a_backend().is_none() + if ( + # Token dispatch/combine will be handled outside of LayerCommunicator for these modes. 
+ not get_moe_a2a_backend().is_none() + or should_use_flashinfer_cutlass_moe_fp4_allgather() + ) else ScatterMode.FULL ) else: diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 3d5d30890..58f6e0f9c 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper: _device: torch.device _global_dp_buffer_len: int _local_dp_buffer_len: int + _global_num_tokens: Optional[List[int]] @classmethod def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device): @@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper: cls._device = device @classmethod - def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int): + def set_dp_buffer_len( + cls, + global_dp_buffer_len: int, + local_dp_buffer_len: int, + global_num_tokens: Optional[List[int]] = None, + ): cls._global_dp_buffer_len = global_dp_buffer_len cls._local_dp_buffer_len = local_dp_buffer_len + cls._global_num_tokens = global_num_tokens @classmethod def get_global_dp_buffer(cls) -> torch.Tensor: @@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper: def get_local_dp_buffer_len(cls) -> int: return cls._local_dp_buffer_len + @classmethod + def get_dp_global_num_tokens(cls) -> List[int]: + return cls._global_num_tokens -def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int): + +def set_dp_buffer_len( + global_dp_buffer_len: int, + local_dp_buffer_len: int, + global_num_tokens: Optional[List[int]] = None, +): _DpGatheredBufferWrapper.set_dp_buffer_len( - global_dp_buffer_len, local_dp_buffer_len + global_dp_buffer_len, local_dp_buffer_len, global_num_tokens ) @@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int: return _DpGatheredBufferWrapper.get_local_dp_buffer_len() +def get_dp_global_num_tokens() -> List[int]: + return _DpGatheredBufferWrapper.get_dp_global_num_tokens() + + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 711aba03f..00b30a848 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -191,7 +191,11 @@ class LogitsMetadata: else: self.global_dp_buffer_len = self.global_dp_buffer_len - set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens) + set_dp_buffer_len( + self.global_dp_buffer_len, + self.dp_local_num_tokens, + self.global_num_tokens_for_logprob_cpu, + ) class LogitsProcessor(nn.Module): diff --git a/python/sglang/srt/layers/moe/__init__.py b/python/sglang/srt/layers/moe/__init__.py index 88bdb5787..e5e5930a2 100644 --- a/python/sglang/srt/layers/moe/__init__.py +++ b/python/sglang/srt/layers/moe/__init__.py @@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import ( get_tbo_token_distribution_threshold, initialize_moe_config, is_tbo_enabled, + should_use_flashinfer_cutlass_moe_fp4_allgather, should_use_flashinfer_trtllm_moe, ) @@ -23,6 +24,7 @@ __all__ = [ "get_moe_runner_backend", "get_deepep_mode", "should_use_flashinfer_trtllm_moe", + "should_use_flashinfer_cutlass_moe_fp4_allgather", "is_tbo_enabled", "get_tbo_token_distribution_threshold", "get_deepep_config", diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 46473ac4c..c5b314988 100644 --- 
a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight @@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module): if "ModelOpt" in self.quant_method.__class__.__name__: # Determine per-tensor weight scale patterns based on variant - is_fp4_variant = ( - "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__ - ) + is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor per_tensor_conditions = ( diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 3df33898a..3b939bca8 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -327,6 +327,13 @@ class TopK(CustomOp): expert_location_dispatch_info=expert_location_dispatch_info, ) + def empty_topk_output(self, device: torch.device) -> TopKOutput: + topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts + topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) + topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device) + router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) + return StandardTopKOutput(topk_weights, topk_idx, router_logits) + # ------------------------------- TopK implementation ------------------------------------- diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 40bd10e23..2fbab220f 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional from packaging import version as pkg_version +from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.layers.dp_attention import ( + get_attention_dp_size, + is_dp_attention_enabled, +) from sglang.srt.utils import logger if TYPE_CHECKING: @@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None IS_TBO_ENABLED: Optional[bool] = None TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None DEEPEP_CONFIG: Optional[str] = None +DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None def initialize_moe_config(server_args: ServerArgs): @@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs): global DEEPEP_CONFIG global IS_TBO_ENABLED global TBO_TOKEN_DISTRIBUTION_THRESHOLD + global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend) MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend) @@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs): DEEPEP_CONFIG = server_args.deepep_config or "" IS_TBO_ENABLED = server_args.enable_two_batch_overlap TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold + DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = ( + server_args.disable_flashinfer_cutlass_moe_fp4_allgather + ) def get_moe_a2a_backend() -> MoeA2ABackend: @@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe(): >= 
pkg_version.parse("0.2.9rc1") ) return result + + +@lru_cache(maxsize=1) +def should_use_flashinfer_cutlass_moe_fp4_allgather(): + """ + Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving. + """ + return ( + not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER + and get_moe_runner_backend().is_flashinfer_cutlass() + and is_dp_attention_enabled() + and get_moe_expert_parallel_world_size() == get_attention_dp_size() + ) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index a77d504a2..7647ec30b 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe +from sglang.srt.distributed import get_tp_group +from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer +from sglang.srt.layers.moe import ( + should_use_flashinfer_cutlass_moe_fp4_allgather, + should_use_flashinfer_trtllm_moe, +) from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( @@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ), "apply_router_weight_on_input is not supported for Flashinfer" # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # and fp4 quantized weights loaded from the checkpoint - topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + output_dtype = x.dtype + x_sf = None + if should_use_flashinfer_cutlass_moe_fp4_allgather(): + from flashinfer import fp4_quantize, nvfp4_block_scale_interleave + + # Quantize before comm, swizzle after. 
+ if x.shape[0] > 0: + x, x_sf = fp4_quantize( + x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False + ) + else: + x_col = x.shape[1] + x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device) + x_sf = torch.zeros( + 0, x_col // 16, dtype=torch.uint8, device=x.device + ) + topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv( + [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens() + ) + x_sf = nvfp4_block_scale_interleave(x_sf) + output = flashinfer_cutlass_fused_moe( - x, - topk_ids.to(torch.int), - topk_weights, - layer.w13_weight.view(torch.long), - layer.w2_weight.view(torch.long), - x.dtype, + input=x, + token_selected_experts=topk_ids.to(torch.int), + token_final_scales=topk_weights, + fc1_expert_weights=layer.w13_weight.view(torch.long), + fc2_expert_weights=layer.w2_weight.view(torch.long), + output_dtype=output_dtype, + input_sf=x_sf, quant_scales=[ layer.w13_input_scale_quant, layer.w13_blockscale_swizzled.view(torch.int32), @@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): )[0] if moe_runner_config.routed_scaling_factor is not None: output *= moe_runner_config.routed_scaling_factor + if should_use_flashinfer_cutlass_moe_fp4_allgather(): + output, global_output = get_local_dp_buffer(), output + get_tp_group().reduce_scatterv( + global_output, output=output, sizes=get_dp_global_num_tokens() + ) return output from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 770fd8cee..5b45154db 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "chunked_prefill_size", "device", "disable_chunked_prefix_cache", + "disable_flashinfer_cutlass_moe_fp4_allgather", "disable_radix_cache", "enable_dp_lm_head", "enable_flashinfer_allreduce_fusion", diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index da2d81fc5..bceb0759e 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -649,7 +649,7 @@ class ForwardBatch: num_tokens = global_num_tokens[0] self.global_dp_buffer_len = buffer_len - set_dp_buffer_len(buffer_len, num_tokens) + set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens) bs = self.batch_size diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2ba57f958..2e0612b78 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -60,7 +60,11 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend +from sglang.srt.layers.moe import ( + get_deepep_mode, + get_moe_a2a_backend, + should_use_flashinfer_cutlass_moe_fp4_allgather, +) from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK @@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module): self.shared_experts_weight_block_size = None if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - # disable tp for shared experts when enable deepep moe + # disable tp 
for shared experts when enable deepep moe, or with fp4 allgather self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module): **( dict(tp_rank=0, tp_size=1) if get_moe_a2a_backend().is_deepep() + or should_use_flashinfer_cutlass_moe_fp4_allgather() else {} ), ) @@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module): if ( self.alt_stream is not None and self.num_fused_shared_experts == 0 + and hidden_states.shape[0] > 0 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD ): return self.forward_normal_dual_stream( - hidden_states, should_allreduce_fusion, use_reduce_scatter + hidden_states, + should_allreduce_fusion, + use_reduce_scatter, ) else: return self.forward_normal( - hidden_states, should_allreduce_fusion, use_reduce_scatter + hidden_states, + should_allreduce_fusion, + use_reduce_scatter, ) else: return self.forward_deepep(hidden_states, forward_batch) @@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module): torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out sm.tag(final_hidden_states) - if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter: + if ( + self.tp_size > 1 + and not should_allreduce_fusion + and not use_reduce_scatter + and not should_use_flashinfer_cutlass_moe_fp4_allgather() + ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -486,10 +501,14 @@ class DeepseekV2MoE(nn.Module): ): return self.forward_cpu(hidden_states, should_allreduce_fusion) - shared_output = self._forward_shared_experts(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) - topk_output = self.topk(hidden_states, router_logits) + if hidden_states.shape[0] > 0: + shared_output = self._forward_shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) + else: + shared_output = None + topk_output = self.topk.empty_topk_output(hidden_states.device) final_hidden_states = self.experts(hidden_states, topk_output) if not _is_cuda and not _use_aiter: @@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module): torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out sm.tag(final_hidden_states) - if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter: + if ( + self.tp_size > 1 + and not should_allreduce_fusion + and not use_reduce_scatter + and not should_use_flashinfer_cutlass_moe_fp4_allgather() + ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states @@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module): ), ) else: - topk_idx = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - topk_weights = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device + topk_weights, topk_idx, _ = self.topk.empty_topk_output( + hidden_states.device ) final_hidden_states = self.experts( diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py index f850bcd25..f8730cd77 100644 --- a/python/sglang/srt/operations.py +++ b/python/sglang/srt/operations.py @@ -84,6 +84,7 @@ class _StageExecutor: forward_batch: ForwardBatch = inputs["forward_batch"] self._global_dp_buffer_len = forward_batch.global_dp_buffer_len 
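+        # Capture DP metadata (buffer lengths and per-rank token counts) so each
+        # stage can restore it via set_dp_buffer_len(); the fp4-allgather path
+        # reads the token counts as uneven split sizes for all_gatherv.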
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+        self._global_num_tokens = forward_batch.global_num_tokens_cpu
 
     def next(self):
         assert not self.done
@@ -91,7 +92,11 @@ class _StageExecutor:
             stage = self._stages[self._index]
 
         if self._global_dp_buffer_len is not None:
-            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+            set_dp_buffer_len(
+                self._global_dp_buffer_len,
+                self._local_dp_buffer_len,
+                self._global_num_tokens,
+            )
 
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 0edc3ca08..fd2bd1580 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -230,6 +230,7 @@ class ServerArgs:
     enable_cudagraph_gc: bool = False
     enable_nccl_nvls: bool = False
     enable_symm_mem: bool = False
+    disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
     enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
@@ -1714,6 +1715,11 @@ class ServerArgs:
             action="store_true",
             help="Enable NCCL symmetric memory for fast collectives.",
         )
+        parser.add_argument(
+            "--disable-flashinfer-cutlass-moe-fp4-allgather",
+            action="store_true",
+            help="Disable FP4 quantization before all-gather for FlashInfer CUTLASS MoE.",
+        )
         parser.add_argument(
             "--enable-tokenizer-batch-encode",
             action="store_true",