[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)
This commit is contained in:
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
||||||
|
|
||||||
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
|
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||||
|
|
||||||
|
|
||||||
class ScatterMode(Enum):
|
class ScatterMode(Enum):
|
||||||
@@ -396,6 +401,17 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
dp_scatter(residual, hidden_states, forward_batch)
|
dp_scatter(residual, hidden_states, forward_batch)
|
||||||
if hidden_states.shape[0] != 0:
|
if hidden_states.shape[0] != 0:
|
||||||
hidden_states = layernorm(hidden_states)
|
hidden_states = layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
_is_sm100_supported
|
||||||
|
and _is_flashinfer_available
|
||||||
|
and hasattr(layernorm, "forward_with_allreduce_fusion")
|
||||||
|
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
|
||||||
|
and hidden_states.shape[0] <= 1024
|
||||||
|
):
|
||||||
|
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
hidden_states, residual = layernorm(hidden_states, residual)
|
hidden_states, residual = layernorm(hidden_states, residual)
|
||||||
|
|||||||
202
python/sglang/srt/layers/flashinfer_comm_fusion.py
Normal file
202
python/sglang/srt/layers/flashinfer_comm_fusion.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_flashinfer_comm = None
|
||||||
|
_workspace_manager = None
|
||||||
|
|
||||||
|
if is_flashinfer_available():
|
||||||
|
try:
|
||||||
|
import flashinfer.comm as comm
|
||||||
|
|
||||||
|
_flashinfer_comm = comm
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"flashinfer.comm is not available, falling back to standard "
|
||||||
|
"implementation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferWorkspaceManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.workspace_tensor = None
|
||||||
|
self.ipc_handles = None
|
||||||
|
self.world_size = None
|
||||||
|
self.rank = None
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self,
|
||||||
|
world_size: int,
|
||||||
|
rank: int,
|
||||||
|
max_token_num: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
group=None,
|
||||||
|
use_fp32_lamport: bool = False,
|
||||||
|
):
|
||||||
|
"""Initialize workspace"""
|
||||||
|
if self.initialized and self.world_size == world_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
if _flashinfer_comm is None:
|
||||||
|
logger.warning(
|
||||||
|
"FlashInfer comm not available, skipping workspace " "initialization"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.cleanup()
|
||||||
|
|
||||||
|
self.ipc_handles, self.workspace_tensor = (
|
||||||
|
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
max_token_num,
|
||||||
|
hidden_dim,
|
||||||
|
group=group,
|
||||||
|
use_fp32_lamport=use_fp32_lamport,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.world_size = world_size
|
||||||
|
self.rank = rank
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"FlashInfer workspace initialized for rank {rank}, "
|
||||||
|
f"world_size {world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up workspace"""
|
||||||
|
if self.initialized and self.ipc_handles is not None:
|
||||||
|
try:
|
||||||
|
_flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
|
||||||
|
self.ipc_handles, group=dist.group.WORLD
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
|
||||||
|
finally:
|
||||||
|
self.workspace_tensor = None
|
||||||
|
self.ipc_handles = None
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
_workspace_manager = FlashInferWorkspaceManager()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_workspace_initialized(
|
||||||
|
max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
|
||||||
|
):
|
||||||
|
"""Ensure workspace is initialized"""
|
||||||
|
if not is_flashinfer_available() or _flashinfer_comm is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
world_size = get_tensor_model_parallel_world_size()
|
||||||
|
if world_size <= 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
|
|
||||||
|
if (
|
||||||
|
not _workspace_manager.initialized
|
||||||
|
or _workspace_manager.world_size != world_size
|
||||||
|
):
|
||||||
|
_workspace_manager.initialize(
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
max_token_num=max_token_num,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
use_fp32_lamport=use_fp32_lamport,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _workspace_manager.initialized
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_allreduce_add_rmsnorm(
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
max_token_num: int = 1024,
|
||||||
|
use_oneshot: bool = True,
|
||||||
|
trigger_completion_at_end: bool = False,
|
||||||
|
fp32_acc: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Use FlashInfer's fused allreduce + residual + RMS norm operation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Input tensor that needs allreduce
|
||||||
|
residual: Residual tensor
|
||||||
|
weight: RMS norm weight
|
||||||
|
eps: RMS norm epsilon
|
||||||
|
max_token_num: Maximum token number
|
||||||
|
use_oneshot: Whether to use oneshot mode
|
||||||
|
trigger_completion_at_end: Whether to trigger completion at end
|
||||||
|
fp32_acc: Whether to use fp32 precision
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
|
||||||
|
"""
|
||||||
|
if not is_flashinfer_available() or _flashinfer_comm is None:
|
||||||
|
logger.debug(
|
||||||
|
"FlashInfer not available, falling back to standard " "implementation"
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
world_size = get_tensor_model_parallel_world_size()
|
||||||
|
if world_size <= 1:
|
||||||
|
logger.debug("Single GPU, no need for allreduce fusion")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not ensure_workspace_initialized(
|
||||||
|
max_token_num=max_token_num,
|
||||||
|
hidden_dim=input_tensor.shape[-1],
|
||||||
|
use_fp32_lamport=(input_tensor.dtype == torch.float32),
|
||||||
|
):
|
||||||
|
logger.debug("FlashInfer workspace not available")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
token_num, hidden_dim = input_tensor.shape
|
||||||
|
|
||||||
|
residual_out = torch.empty_like(residual)
|
||||||
|
norm_out = torch.empty_like(input_tensor)
|
||||||
|
|
||||||
|
_flashinfer_comm.trtllm_allreduce_fusion(
|
||||||
|
allreduce_in=input_tensor,
|
||||||
|
world_size=world_size,
|
||||||
|
world_rank=dist.get_rank(),
|
||||||
|
token_num=token_num,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
workspace_ptrs=_workspace_manager.workspace_tensor,
|
||||||
|
launch_with_pdl=True,
|
||||||
|
use_oneshot=use_oneshot,
|
||||||
|
trigger_completion_at_end=trigger_completion_at_end,
|
||||||
|
fp32_acc=fp32_acc,
|
||||||
|
pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
|
||||||
|
allreduce_out=None,
|
||||||
|
residual_in=residual,
|
||||||
|
residual_out=residual_out,
|
||||||
|
norm_out=norm_out,
|
||||||
|
quant_out=None,
|
||||||
|
scale_out=None,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=eps,
|
||||||
|
scale_factor=None,
|
||||||
|
layout_code=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return norm_out, residual_out
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_flashinfer_workspace():
|
||||||
|
global _workspace_manager
|
||||||
|
if _workspace_manager is not None:
|
||||||
|
_workspace_manager.cleanup()
|
||||||
@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
|
|||||||
else:
|
else:
|
||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
|
def forward_with_allreduce_fusion(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Forward method with allreduce fusion, prioritizing flashinfer fused operations
|
||||||
|
"""
|
||||||
|
if residual is not None:
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from sglang.srt.layers.flashinfer_comm_fusion import (
|
||||||
|
flashinfer_allreduce_add_rmsnorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
if get_tensor_model_parallel_world_size() > 1:
|
||||||
|
fused_result = flashinfer_allreduce_add_rmsnorm(
|
||||||
|
input_tensor=x,
|
||||||
|
residual=residual,
|
||||||
|
weight=self.weight,
|
||||||
|
eps=self.variance_epsilon,
|
||||||
|
)
|
||||||
|
if fused_result[0] is not None:
|
||||||
|
return fused_result
|
||||||
|
|
||||||
|
return self.forward(x, residual)
|
||||||
|
|
||||||
|
|
||||||
class GemmaRMSNorm(CustomOp):
|
class GemmaRMSNorm(CustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"deepep_mode",
|
"deepep_mode",
|
||||||
"enable_ep_moe",
|
"enable_ep_moe",
|
||||||
"enable_flashinfer_moe",
|
"enable_flashinfer_moe",
|
||||||
|
"enable_flashinfer_allreduce_fusion",
|
||||||
"moe_dense_tp_size",
|
"moe_dense_tp_size",
|
||||||
"ep_dispatch_algorithm",
|
"ep_dispatch_algorithm",
|
||||||
"deepep_config",
|
"deepep_config",
|
||||||
|
|||||||
@@ -157,6 +157,7 @@ class ServerArgs:
|
|||||||
enable_ep_moe: bool = False
|
enable_ep_moe: bool = False
|
||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
enable_flashinfer_moe: bool = False
|
enable_flashinfer_moe: bool = False
|
||||||
|
enable_flashinfer_allreduce_fusion: bool = False
|
||||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||||
ep_num_redundant_experts: int = 0
|
ep_num_redundant_experts: int = 0
|
||||||
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
|
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
|
||||||
@@ -1206,6 +1207,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
|
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-flashinfer-allreduce-fusion",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user