Refactor allreduce add rmsnorm pattern (#9278)
This commit is contained in:
@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
get_global_dp_buffer,
|
get_global_dp_buffer,
|
||||||
get_local_dp_buffer,
|
get_local_dp_buffer,
|
||||||
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe import (
|
from sglang.srt.layers.moe import (
|
||||||
get_moe_a2a_backend,
|
get_moe_a2a_backend,
|
||||||
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
|
|||||||
_is_flashinfer_available = is_flashinfer_available()
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||||
|
|
||||||
|
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
|
||||||
|
|
||||||
|
|
||||||
class ScatterMode(Enum):
|
class ScatterMode(Enum):
|
||||||
"""
|
"""
|
||||||
@@ -162,11 +165,13 @@ class LayerCommunicator:
|
|||||||
post_attention_layernorm: torch.nn.Module,
|
post_attention_layernorm: torch.nn.Module,
|
||||||
# Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
|
# Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
|
||||||
allow_reduce_scatter: bool = False,
|
allow_reduce_scatter: bool = False,
|
||||||
|
is_last_layer: bool = False,
|
||||||
):
|
):
|
||||||
self.layer_scatter_modes = layer_scatter_modes
|
self.layer_scatter_modes = layer_scatter_modes
|
||||||
self.input_layernorm = input_layernorm
|
self.input_layernorm = input_layernorm
|
||||||
self.post_attention_layernorm = post_attention_layernorm
|
self.post_attention_layernorm = post_attention_layernorm
|
||||||
self.allow_reduce_scatter = allow_reduce_scatter
|
self.allow_reduce_scatter = allow_reduce_scatter
|
||||||
|
self.is_last_layer = is_last_layer
|
||||||
|
|
||||||
self._context = CommunicateContext.init_new()
|
self._context = CommunicateContext.init_new()
|
||||||
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
|
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
|
||||||
@@ -264,6 +269,42 @@ class LayerCommunicator:
|
|||||||
and forward_batch.dp_padding_mode.is_max_len()
|
and forward_batch.dp_padding_mode.is_max_len()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def should_fuse_mlp_allreduce_with_next_layer(
|
||||||
|
self, forward_batch: ForwardBatch
|
||||||
|
) -> bool:
|
||||||
|
speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
|
||||||
|
if (
|
||||||
|
is_dp_attention_enabled()
|
||||||
|
and speculative_algo is not None
|
||||||
|
and speculative_algo.is_eagle()
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
batch_size = (
|
||||||
|
forward_batch.input_ids.shape[0]
|
||||||
|
if hasattr(forward_batch, "input_ids")
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
|
||||||
|
return False
|
||||||
|
|
||||||
|
static_conditions_met = (
|
||||||
|
(not self.is_last_layer)
|
||||||
|
and (self._context.tp_size > 1)
|
||||||
|
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
|
||||||
|
and _is_sm100_supported
|
||||||
|
and _is_flashinfer_available
|
||||||
|
)
|
||||||
|
|
||||||
|
if not static_conditions_met:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (
|
||||||
|
batch_size > 0
|
||||||
|
and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
|
||||||
|
and (not self.is_last_layer)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommunicateContext:
|
class CommunicateContext:
|
||||||
|
|||||||
@@ -1852,10 +1852,11 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
input_layernorm=self.input_layernorm,
|
input_layernorm=self.input_layernorm,
|
||||||
post_attention_layernorm=self.post_attention_layernorm,
|
post_attention_layernorm=self.post_attention_layernorm,
|
||||||
allow_reduce_scatter=True,
|
allow_reduce_scatter=True,
|
||||||
|
is_last_layer=(
|
||||||
|
is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
|
|
||||||
|
|
||||||
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
|
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
|
||||||
return is_nextn or (
|
return is_nextn or (
|
||||||
self.config.n_routed_experts is not None
|
self.config.n_routed_experts is not None
|
||||||
@@ -1863,20 +1864,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
and layer_id % self.config.moe_layer_freq == 0
|
and layer_id % self.config.moe_layer_freq == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
|
|
||||||
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
|
|
||||||
|
|
||||||
batch_size = (
|
|
||||||
forward_batch.input_ids.shape[0]
|
|
||||||
if hasattr(forward_batch, "input_ids")
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_size > 128:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self._fuse_allreduce_lookup_table.get(batch_size, False)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -1902,11 +1889,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
should_allreduce_fusion = (
|
should_allreduce_fusion = (
|
||||||
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
|
self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
|
||||||
and not (
|
forward_batch
|
||||||
is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
|
|
||||||
)
|
)
|
||||||
and not self.is_nextn
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
||||||
@@ -1997,26 +1982,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _build_fuse_allreduce_lookup_table(self):
|
|
||||||
static_conditions_met = (
|
|
||||||
self.layer_id != self.config.num_hidden_layers - 1
|
|
||||||
and get_tensor_model_parallel_world_size() > 1
|
|
||||||
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
|
|
||||||
and _is_sm100_supported
|
|
||||||
and _is_flashinfer_available
|
|
||||||
)
|
|
||||||
|
|
||||||
if not static_conditions_met:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
lookup_table = {}
|
|
||||||
for batch_size in range(129): # 0 to 128
|
|
||||||
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
|
|
||||||
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
|
|
||||||
lookup_table[batch_size] = should_fuse
|
|
||||||
|
|
||||||
return lookup_table
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Model(nn.Module):
|
class DeepseekV2Model(nn.Module):
|
||||||
fall_back_to_pt_during_load = False
|
fall_back_to_pt_during_load = False
|
||||||
|
|||||||
@@ -453,44 +453,11 @@ class GptOssDecoderLayer(nn.Module):
|
|||||||
layer_scatter_modes=self.layer_scatter_modes,
|
layer_scatter_modes=self.layer_scatter_modes,
|
||||||
input_layernorm=self.input_layernorm,
|
input_layernorm=self.input_layernorm,
|
||||||
post_attention_layernorm=self.post_attention_layernorm,
|
post_attention_layernorm=self.post_attention_layernorm,
|
||||||
|
is_last_layer=(
|
||||||
|
self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
|
|
||||||
|
|
||||||
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
|
|
||||||
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
|
|
||||||
|
|
||||||
batch_size = (
|
|
||||||
forward_batch.input_ids.shape[0]
|
|
||||||
if hasattr(forward_batch, "input_ids")
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_size > 128:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self._fuse_allreduce_lookup_table.get(batch_size, False)
|
|
||||||
|
|
||||||
def _build_fuse_allreduce_lookup_table(self):
|
|
||||||
static_conditions_met = (
|
|
||||||
self.layer_id != self.config.num_hidden_layers - 1
|
|
||||||
and get_tensor_model_parallel_world_size() > 1
|
|
||||||
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
|
|
||||||
and _is_sm100_supported
|
|
||||||
and _is_flashinfer_available
|
|
||||||
)
|
|
||||||
|
|
||||||
if not static_conditions_met:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
lookup_table = {}
|
|
||||||
for batch_size in range(129): # 0 to 128
|
|
||||||
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
|
|
||||||
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
|
|
||||||
lookup_table[batch_size] = should_fuse
|
|
||||||
|
|
||||||
return lookup_table
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -514,8 +481,9 @@ class GptOssDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
should_allreduce_fusion = (
|
should_allreduce_fusion = (
|
||||||
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
|
self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
|
||||||
and not self.is_nextn
|
forward_batch
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
|
hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
|
||||||
|
|||||||
Reference in New Issue
Block a user