[ready b200] fuse allreduce+add_rmsnorm in prepare_attention + mlp module (#7775)

Xiaoyu Zhang
2025-07-11 06:12:39 +08:00
committed by GitHub
parent 766392c6bd
commit 49a5915f53
3 changed files with 85 additions and 20 deletions

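The hunk below changes LayerCommunicator so that, when the incoming hidden_states tensor has been tagged with _sglang_needs_allreduce_fusion, the tensor-parallel all-reduce deferred by the previous module is fused with the residual add + RMSNorm rather than running as a separate collective followed by a separate norm kernel. A minimal sketch of the intended semantics in plain PyTorch (not the fused sglang kernel; only forward_with_allreduce_fusion and the flag name come from the diff, the rest is illustrative):

    from typing import Optional

    import torch
    import torch.distributed as dist


    class RMSNorm(torch.nn.Module):
        """Minimal RMSNorm with the fused-allreduce entry point from the diff."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def _norm(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.pow(2).mean(-1, keepdim=True)
            return x * torch.rsqrt(variance + self.eps) * self.weight

        def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
            if residual is None:
                # Norm-only path: caller keeps the pre-norm tensor as the residual.
                return self._norm(x)
            # Fused add + RMSNorm path: fold x into the residual stream, then norm.
            residual = residual + x
            return self._norm(residual), residual

        def forward_with_allreduce_fusion(self, x: torch.Tensor, residual: torch.Tensor):
            # Unfused reference semantics: finish the tensor-parallel all-reduce
            # that the previous module deferred, then run add + RMSNorm. The
            # commit's point is to do all three steps in one fused kernel launch.
            if dist.is_initialized():
                dist.all_reduce(x)
            return self.forward(x, residual)
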
@@ -187,11 +187,24 @@ class LayerCommunicator:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                if residual is None:
+                    residual = hidden_states
+                    hidden_states = self.input_layernorm(hidden_states)
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )
 
         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
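
For the fused branch above to be taken, an upstream module (the attention or MLP output projection named in the commit title) has to skip its own all-reduce and tag the partial result. A hypothetical sketch of that producer side; the attribute name comes from the diff, the helper does not:

    import torch


    # Hypothetical: tag per-rank partial sums so the fused path above finishes
    # the all-reduce inside the add-RMSNorm kernel. torch.Tensor instances
    # accept arbitrary Python attributes, which is what the hasattr() check in
    # the diff relies on.
    def defer_allreduce(partial_output: torch.Tensor) -> torch.Tensor:
        partial_output._sglang_needs_allreduce_fusion = True
        return partial_output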