Refactor allreduce add rmsnorm pattern (#9278)

Xiaoyu Zhang
2025-08-20 17:03:08 +08:00
committed by GitHub
parent 08ebdf79d0
commit f96413c444
3 changed files with 52 additions and 78 deletions


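For context, the pattern this commit targets is the tensor-parallel all-reduce at the end of an MLP, followed by the next layer's residual add and RMSNorm. A minimal unfused PyTorch sketch of that three-step sequence (illustrative only: the function and argument names are assumptions, and an initialized torch.distributed process group is assumed; the point of the refactor is to route eligible batches to a single fused FlashInfer kernel instead):

import torch
import torch.distributed as dist

def unfused_allreduce_add_rmsnorm(partial_out, residual, weight, eps=1e-6):
    # 1) Sum the per-rank partial MLP outputs across the TP group.
    dist.all_reduce(partial_out)
    # 2) Fold in the residual stream.
    hidden = partial_out + residual
    # 3) RMSNorm in fp32 for stability: x * rsqrt(mean(x^2) + eps) * weight.
    x = hidden.float()
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    # Return the normalized activations and the new residual stream.
    return (normed * weight.float()).to(hidden.dtype), hidden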
@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -162,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping the all-reduce in model code after MoE/MLP, so only enable it for models that implement that. Remove this flag once it is done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
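Call sites now thread the new flag through the constructor; a hypothetical construction sketch (layer_id and config are assumptions, not shown in this diff):

self.layer_communicator = LayerCommunicator(
    layer_scatter_modes=layer_scatter_modes,
    input_layernorm=self.input_layernorm,
    post_attention_layernorm=self.post_attention_layernorm,
    allow_reduce_scatter=True,
    # The last layer has no next layer whose add + rmsnorm could absorb
    # the all-reduce, so fusion must be disabled there.
    is_last_layer=(layer_id == config.num_hidden_layers - 1),
)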
@@ -264,6 +269,42 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
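Downstream, a decoder layer can consult the new predicate to decide whether to skip its own MLP all-reduce and let the next layer's fused allreduce + add + rmsnorm kernel perform the reduction. A hypothetical call-site sketch (the skip_all_reduce parameter and the surrounding flow are assumptions, not part of this diff):

use_fused = self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
    forward_batch
)
# When fusing, the MLP leaves its output partial (per-rank); the fused
# FlashInfer kernel in the next layer then performs allreduce + add + rmsnorm.
hidden_states = self.mlp(hidden_states, skip_all_reduce=use_fused)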