Refactor allreduce add rmsnorm pattern (#9278)
@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -162,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
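
Note: with the new is_last_layer flag, call sites are expected to tag the final decoder layer so the communicator can decline to fuse past it. A minimal sketch of such a call site, as a hedged illustration only (the surrounding layer class, layer_id, and config.num_hidden_layers are assumptions, not part of this diff):

    # Hypothetical decoder-layer __init__ (illustrative, not from this commit).
    self.layer_communicator = LayerCommunicator(
        layer_scatter_modes=self.layer_scatter_modes,
        input_layernorm=self.input_layernorm,
        post_attention_layernorm=self.post_attention_layernorm,
        allow_reduce_scatter=True,
        # New argument: the last layer has no next-layer input RMSNorm to
        # fuse the all-reduce into, so fusion must be disabled there.
        is_last_layer=(layer_id == config.num_hidden_layers - 1),
    )
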
@@ -264,6 +269,42 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
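
Note: a minimal sketch of how an MLP epilogue might consume the new predicate, assuming a fused flashinfer allreduce+RMSNorm wrapper; fused_allreduce_rmsnorm is a placeholder name, and tensor_model_parallel_all_reduce stands in for the regular tensor-parallel all-reduce path (neither call is shown in this diff):

    # Hypothetical forward-pass epilogue (illustrative, not from this commit).
    if self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(forward_batch):
        # Defer the TP all-reduce and fuse it with the next layer's input
        # RMSNorm in a single kernel (placeholder wrapper name).
        hidden_states, residual = fused_allreduce_rmsnorm(
            hidden_states, residual, next_layer_norm_weight
        )
    else:
        # Fallback: plain TP all-reduce; the next layer applies its own RMSNorm.
        hidden_states = tensor_model_parallel_all_reduce(hidden_states)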