[ready b200] fuse allreduce+add_rmsnorm in prepare_attention + mlp module (#7775)
@@ -187,11 +187,24 @@ class LayerCommunicator:
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual
-            )
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
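For context on what the fused path replaces, here is a minimal reference sketch of the unfused sequence: all-reduce the tensor-parallel partial output, add the residual, then apply RMSNorm. This is not the code behind `forward_with_allreduce_fusion` in this PR; the function name `allreduce_add_rmsnorm_reference` and its signature are illustrative only, and it assumes a default `torch.distributed` process group when one is initialized.

```python
# Illustrative reference only: spells out the three steps that the fused
# allreduce + add_rmsnorm path collapses into (ideally) a single kernel launch.
import torch
import torch.distributed as dist


def allreduce_add_rmsnorm_reference(
    hidden_states: torch.Tensor,  # tensor-parallel partial sums (e.g. MLP output)
    residual: torch.Tensor,       # residual stream carried between layers
    norm_weight: torch.Tensor,    # RMSNorm scale parameter
    eps: float = 1e-6,
):
    # 1) All-reduce the partial results across the tensor-parallel group.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(hidden_states)

    # 2) Residual add; the sum becomes the residual for the next layer.
    residual = residual + hidden_states

    # 3) RMSNorm over the hidden dimension (accumulated in fp32, as is typical).
    variance = residual.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    hidden_states = (
        residual.to(torch.float32) * torch.rsqrt(variance + eps)
    ).to(residual.dtype) * norm_weight
    return hidden_states, residual
```

A fused kernel performs the same math in one launch, so the all-reduced activations are normalized without an extra round trip through global memory, which is the main saving the fusion targets.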
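The new branch is gated on a `_sglang_needs_allreduce_fusion` attribute carried by `hidden_states`. Only the attribute name comes from the diff; the producer/consumer helpers below (`mark_for_allreduce_fusion`, `wants_allreduce_fusion`) are hypothetical, sketching how such a flag could be set by the module that skipped its own all-reduce and then read by the next layer's input layernorm.

```python
import torch


def mark_for_allreduce_fusion(hidden_states: torch.Tensor) -> torch.Tensor:
    # Hypothetical producer side: a tensor-parallel MLP / attention output
    # projection skips its usual all_reduce and tags the partial result so
    # the next prepare_attn call knows it must fuse allreduce + add_rmsnorm.
    hidden_states._sglang_needs_allreduce_fusion = True
    return hidden_states


def wants_allreduce_fusion(hidden_states: torch.Tensor) -> bool:
    # Consumer side: mirrors the `hasattr(...) and ...` check in the diff above.
    return getattr(hidden_states, "_sglang_needs_allreduce_fusion", False)
```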