diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 4af27ad69..a8b1223cc 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465,
+            # we set the max token num to 128 for allreduce fusion in the min-latency case (use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py
index fb78218c3..49b09bedd 100644
--- a/python/sglang/srt/layers/flashinfer_comm_fusion.py
+++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
 
 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
         return
@@ -119,12 +119,12 @@
     return _workspace_manager.initialized
 
 
-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 78b4a0513..0ad32a380 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )
 
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
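
Note: for reference, a minimal single-process sketch of the math that the renamed
flashinfer_allreduce_residual_rmsnorm fuses into one kernel. This is not the
FlashInfer implementation; it emulates the tensor-parallel all-reduce with a sum
over per-rank partials and mirrors the (normed_output, new_residual) return
convention that the RMSNorm call site in layernorm.py relies on. All names in the
sketch are hypothetical.

    import torch

    def reference_allreduce_residual_rmsnorm(
        partial_inputs,  # list of per-rank partial hidden states, each [tokens, hidden]
        residual,        # [tokens, hidden]
        weight,          # [hidden] RMSNorm gain
        eps=1e-6,
    ):
        # Emulate the tensor-parallel all-reduce: sum the per-rank partial outputs.
        reduced = torch.stack(partial_inputs).sum(dim=0)
        # Residual add; the fused op also returns this pre-norm sum as the new residual.
        new_residual = reduced + residual
        # RMSNorm computed in fp32 for stability, then cast back to the input dtype.
        variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
        normed = new_residual.float() * torch.rsqrt(variance + eps)
        return (normed * weight.float()).to(residual.dtype), new_residual

Fusing these three steps saves the separate kernel launches and the extra global
memory round-trips between the all-reduce, the residual add, and the norm; per the
linked flashinfer issue, the one-shot (min-latency) variant only pays off for small
token counts, which is why the gate drops from 1024 to 128.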
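
One maintenance note, with a hypothetical sketch: the patch hard-codes 128 in three
places (the hidden_states.shape[0] gate in communicator.py and the max_token_num
defaults in flashinfer_comm_fusion.py). A shared constant would keep the gate and
the workspace sizing from drifting apart; MAX_FUSION_TOKENS and can_use_fused_path
below are illustrative names, not part of the patch.

    # The fused path is only safe when the workspace was sized for at least as many
    # tokens as the communicator lets through, so both bounds should share one source.
    MAX_FUSION_TOKENS = 128

    def can_use_fused_path(num_tokens: int) -> bool:
        return num_tokens <= MAX_FUSION_TOKENS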