From faa25df1ae6c33162dacd31153f8c56b9da1d2db Mon Sep 17 00:00:00 2001
From: eigen <52445717+yyihuang@users.noreply.github.com>
Date: Sat, 9 Aug 2025 03:51:27 -0400
Subject: [PATCH] feat: update flashinfer ar oneshot params (#8687)

---
 python/sglang/srt/layers/communicator.py           | 1 -
 python/sglang/srt/layers/flashinfer_comm_fusion.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index a497db464..95bf1514c 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -443,7 +443,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 128
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py
index 49b09bedd..1869beb56 100644
--- a/python/sglang/srt/layers/flashinfer_comm_fusion.py
+++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py
@@ -125,7 +125,7 @@ def flashinfer_allreduce_residual_rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     max_token_num: int = 128,
-    use_oneshot: bool = True,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]: