[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)
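
This change adds an opt-in fused allreduce + residual-add + RMSNorm path for
tensor-parallel inference on SM100 (B200) GPUs, built on FlashInfer's
TensorRT-LLM kernels (flashinfer.comm.trtllm_allreduce_fusion). The fusion is
gated behind a new --enable-flashinfer-allreduce-fusion server flag, applies
only to batches of at most 1024 tokens, and otherwise falls back to the
unfused tensor_model_parallel_all_reduce + layernorm sequence.

A minimal sketch of enabling the flag from the offline engine API (the model
path and tp_size are illustrative, and this assumes sgl.Engine forwards extra
keyword arguments into ServerArgs):

import sglang as sgl

# Fusion requires tensor parallelism (world size > 1), SM100 hardware, and a
# flashinfer build that provides flashinfer.comm; otherwise generation
# silently uses the unfused allreduce + add + rmsnorm path.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    tp_size=2,
    enable_flashinfer_allreduce_fusion=True,
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))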
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 class ScatterMode(Enum):
@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
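+            # Take the fused TRT-LLM kernel only when every precondition
+            # holds: SM100 (B200) hardware, flashinfer importable, a norm
+            # layer exposing the fused method, the server flag enabled, and
+            # a batch of at most 1024 tokens.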
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 1024
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
     @staticmethod

python/sglang/srt/layers/flashinfer_comm_fusion.py (new file, 202 lines)
@@ -0,0 +1,202 @@
import logging
from typing import Tuple

import torch
import torch.distributed as dist

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None

if is_flashinfer_available():
    try:
        import flashinfer.comm as comm

        _flashinfer_comm = comm
    except ImportError:
        logger.warning(
            "flashinfer.comm is not available, falling back to standard implementation"
        )

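# The TRT-LLM fusion kernels exchange data through a shared IPC workspace.
# This manager creates that workspace once per tensor-parallel world size and
# releases it again through cleanup() / cleanup_flashinfer_workspace().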
class FlashInferWorkspaceManager:
    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Initialize workspace"""
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning(
                "FlashInfer comm not available, skipping workspace initialization"
            )
            return

        self.cleanup()

        self.ipc_handles, self.workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                max_token_num,
                hidden_dim,
                group=group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(
            f"FlashInfer workspace initialized for rank {rank}, "
            f"world_size {world_size}"
        )

    def cleanup(self):
        """Clean up workspace"""
        if self.initialized and self.ipc_handles is not None:
            try:
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                    self.ipc_handles, group=dist.group.WORLD
                )
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False

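# Process-wide singleton: all fused calls in this process share one workspace.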
_workspace_manager = FlashInferWorkspaceManager()

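# Cheap to call per invocation: returns immediately once the workspace exists
# for the current world size, and re-initializes only when that size changes.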
def ensure_workspace_initialized(
    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Ensure workspace is initialized"""
    if not is_flashinfer_available() or _flashinfer_comm is None:
        return False

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        return False

    rank = dist.get_rank()

    if (
        not _workspace_manager.initialized
        or _workspace_manager.world_size != world_size
    ):
        _workspace_manager.initialize(
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized

def flashinfer_allreduce_add_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 1024,
    use_oneshot: bool = True,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual add + RMSNorm operation.

    Args:
        input_tensor: Input tensor that needs allreduce
        residual: Residual tensor
        weight: RMS norm weight
        eps: RMS norm epsilon
        max_token_num: Maximum number of tokens the workspace is sized for
        use_oneshot: Whether to use the one-shot allreduce mode
        trigger_completion_at_end: Whether to trigger completion at the end
        fp32_acc: Whether to accumulate the allreduce in fp32

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output),
        or (None, None) when the fused path cannot be used
    """
    if not is_flashinfer_available() or _flashinfer_comm is None:
        logger.debug(
            "FlashInfer not available, falling back to standard implementation"
        )
        return None, None

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

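    # The workspace must be sized for this hidden dim, and fp32 activations
    # need the fp32 Lamport buffer layout, hence use_fp32_lamport from dtype.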
    if not ensure_workspace_initialized(
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == torch.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = torch.empty_like(residual)
    norm_out = torch.empty_like(input_tensor)

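    # Single kernel launch: allreduce, residual add, and RMSNorm together.
    # Only the norm/residual outputs are requested; quant_out/scale_out stay
    # None since the kARResidualRMSNorm pattern performs no quantization.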
    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out

def cleanup_flashinfer_workspace():
    global _workspace_manager
    if _workspace_manager is not None:
        _workspace_manager.cleanup()

@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward pass with allreduce fusion, prioritizing flashinfer fused ops.
+        """
+        if residual is not None:
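+            # Local imports keep the flashinfer-dependent module from being
+            # loaded unless this fused path is actually exercised.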
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_add_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_add_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
 
 
 class GemmaRMSNorm(CustomOp):
     def __init__(

@@ -85,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "deepep_mode",
     "enable_ep_moe",
     "enable_flashinfer_moe",
+    "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",

@@ -157,6 +157,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None

@@ -1206,6 +1207,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",