From 8e64140e35c15a626d199a0dfdd9cc7f956ab6cc Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 3 Jul 2025 10:36:20 +0800 Subject: [PATCH] [b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621) --- python/sglang/srt/layers/communicator.py | 20 +- .../srt/layers/flashinfer_comm_fusion.py | 202 ++++++++++++++++++ python/sglang/srt/layers/layernorm.py | 26 +++ python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/server_args.py | 6 + 5 files changed, 253 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/layers/flashinfer_comm_fusion.py diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 42d2ec2a3..4af27ad69 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, ) +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.utils import is_cuda, is_flashinfer_available + +_is_flashinfer_available = is_flashinfer_available() +_is_sm100_supported = is_cuda() and is_sm100_supported() class ScatterMode(Enum): @@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn: if hidden_states.shape[0] != 0: hidden_states = layernorm(hidden_states) else: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - hidden_states, residual = layernorm(hidden_states, residual) + if ( + _is_sm100_supported + and _is_flashinfer_available + and hasattr(layernorm, "forward_with_allreduce_fusion") + and global_server_args_dict["enable_flashinfer_allreduce_fusion"] + and hidden_states.shape[0] <= 1024 + ): + hidden_states, residual = layernorm.forward_with_allreduce_fusion( + hidden_states, residual + ) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual @staticmethod diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py new file mode 100644 index 000000000..fb78218c3 --- /dev/null +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -0,0 +1,202 @@ +import logging +from typing import Tuple + +import torch +import torch.distributed as dist + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.utils import is_flashinfer_available + +logger = logging.getLogger(__name__) + +_flashinfer_comm = None +_workspace_manager = None + +if is_flashinfer_available(): + try: + import flashinfer.comm as comm + + _flashinfer_comm = comm + except ImportError: + logger.warning( + "flashinfer.comm is not available, falling back to standard " + "implementation" + ) + + +class FlashInferWorkspaceManager: + def __init__(self): + self.workspace_tensor = None + self.ipc_handles = None + self.world_size = None + self.rank = None + self.initialized = False + + def initialize( + self, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + group=None, + use_fp32_lamport: bool = False, + ): + """Initialize workspace""" + if self.initialized and self.world_size == world_size: + return + + if _flashinfer_comm is None: + logger.warning( + "FlashInfer comm not available, skipping workspace " "initialization" + ) + return + + self.cleanup() + + self.ipc_handles, self.workspace_tensor = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + max_token_num, + hidden_dim, + group=group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + self.world_size = world_size + self.rank = rank + self.initialized = True + + logger.info( + f"FlashInfer workspace initialized for rank {rank}, " + f"world_size {world_size}" + ) + + def cleanup(self): + """Clean up workspace""" + if self.initialized and self.ipc_handles is not None: + try: + _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( + self.ipc_handles, group=dist.group.WORLD + ) + except Exception as e: + logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") + finally: + self.workspace_tensor = None + self.ipc_handles = None + self.initialized = False + + +_workspace_manager = FlashInferWorkspaceManager() + + +def ensure_workspace_initialized( + max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False +): + """Ensure workspace is initialized""" + if not is_flashinfer_available() or _flashinfer_comm is None: + return False + + world_size = get_tensor_model_parallel_world_size() + if world_size <= 1: + return False + + rank = dist.get_rank() + + if ( + not _workspace_manager.initialized + or _workspace_manager.world_size != world_size + ): + _workspace_manager.initialize( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + use_fp32_lamport=use_fp32_lamport, + ) + + return _workspace_manager.initialized + + +def flashinfer_allreduce_add_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + max_token_num: int = 1024, + use_oneshot: bool = True, + trigger_completion_at_end: bool = False, + fp32_acc: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Use FlashInfer's fused allreduce + residual + RMS norm operation + + Args: + input_tensor: Input tensor that needs allreduce + residual: Residual tensor + weight: RMS norm weight + eps: RMS norm epsilon + max_token_num: Maximum token number + use_oneshot: Whether to use oneshot mode + trigger_completion_at_end: Whether to trigger completion at end + fp32_acc: Whether to use fp32 precision + + Returns: + Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output) + """ + if not is_flashinfer_available() or _flashinfer_comm is None: + logger.debug( + "FlashInfer not available, falling back to standard " "implementation" + ) + return None, None + + world_size = get_tensor_model_parallel_world_size() + if world_size <= 1: + logger.debug("Single GPU, no need for allreduce fusion") + return None, None + + if not ensure_workspace_initialized( + max_token_num=max_token_num, + hidden_dim=input_tensor.shape[-1], + use_fp32_lamport=(input_tensor.dtype == torch.float32), + ): + logger.debug("FlashInfer workspace not available") + return None, None + + token_num, hidden_dim = input_tensor.shape + + residual_out = torch.empty_like(residual) + norm_out = torch.empty_like(input_tensor) + + _flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + world_size=world_size, + world_rank=dist.get_rank(), + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=_workspace_manager.workspace_tensor, + launch_with_pdl=True, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm), + allreduce_out=None, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=eps, + scale_factor=None, + layout_code=None, + ) + + return norm_out, residual_out + + +def cleanup_flashinfer_workspace(): + global _workspace_manager + if _workspace_manager is not None: + _workspace_manager.cleanup() diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 5d8106f17..78b4a0513 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -163,6 +163,32 @@ class RMSNorm(CustomOp): else: return self.forward_native(x, residual) + def forward_with_allreduce_fusion( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward method with allreduce fusion, prioritizing flashinfer fused operations + """ + if residual is not None: + from sglang.srt.distributed import get_tensor_model_parallel_world_size + from sglang.srt.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_add_rmsnorm, + ) + + if get_tensor_model_parallel_world_size() > 1: + fused_result = flashinfer_allreduce_add_rmsnorm( + input_tensor=x, + residual=residual, + weight=self.weight, + eps=self.variance_epsilon, + ) + if fused_result[0] is not None: + return fused_result + + return self.forward(x, residual) + class GemmaRMSNorm(CustomOp): def __init__( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 864eaf1ee..b257fe6ef 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -85,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "deepep_mode", "enable_ep_moe", "enable_flashinfer_moe", + "enable_flashinfer_allreduce_fusion", "moe_dense_tp_size", "ep_dispatch_algorithm", "deepep_config", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0fb3c6af9..76e0272a8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -157,6 +157,7 @@ class ServerArgs: enable_ep_moe: bool = False enable_deepep_moe: bool = False enable_flashinfer_moe: bool = False + enable_flashinfer_allreduce_fusion: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None @@ -1206,6 +1207,11 @@ class ServerArgs: action="store_true", help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe", ) + parser.add_argument( + "--enable-flashinfer-allreduce-fusion", + action="store_true", + help="Enable FlashInfer allreduce fusion for Add_RMSNorm.", + ) parser.add_argument( "--enable-deepep-moe", action="store_true",