From 88fbc31b50be3f2ef68bff42b39cbf4aa09ca8b3 Mon Sep 17 00:00:00 2001
From: strgrb
Date: Thu, 21 Aug 2025 07:54:30 +0800
Subject: [PATCH] Support trtllm_allreduce_fusion in flashinfer for cuda<12.8
 (#9339)

Co-authored-by: Zhang Kaihong
---
 python/sglang/srt/layers/communicator.py          |  1 -
 .../srt/layers/flashinfer_comm_fusion.py          | 30 ++++++++++++++++++-
 python/sglang/srt/layers/layernorm.py             |  9 +++++-
 3 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 3f8973830..6e578afe0 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -292,7 +292,6 @@ class LayerCommunicator:
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
             and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
             and _is_flashinfer_available
         )
 
diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py
index 023db709c..81280db0a 100644
--- a/python/sglang/srt/layers/flashinfer_comm_fusion.py
+++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py
@@ -5,7 +5,11 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    is_flashinfer_available,
+    supports_custom_op,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
     return norm_out, residual_out
 
 
+def fake_flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+    return norm_out, residual_out
+
+
+if supports_custom_op():
+    direct_register_custom_op(
+        "flashinfer_allreduce_residual_rmsnorm",
+        flashinfer_allreduce_residual_rmsnorm,
+        mutates_args=["input_tensor", "residual", "weight"],
+        fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+    )
+
+
 def cleanup_flashinfer_workspace():
     global _workspace_manager
     if _workspace_manager is not None:
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 4c1f2268b..a77747351 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    supports_custom_op,
 )
 
 _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
             flashinfer_allreduce_residual_rmsnorm,
         )
 
+        fused_op = (
+            torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+            if supports_custom_op()
+            else flashinfer_allreduce_residual_rmsnorm
+        )
+
         if get_tensor_model_parallel_world_size() > 1:
-            fused_result = flashinfer_allreduce_residual_rmsnorm(
+            fused_result = fused_op(
                 input_tensor=x,
                 residual=residual,
                 weight=self.weight,