fuse allreduce and residual_rmsnorm (#8731)

This commit is contained in:
Xiaoyu Zhang
2025-08-12 04:50:53 +08:00
committed by GitHub
parent 8c07fabda7
commit 44e86480e8
8 changed files with 135 additions and 59 deletions

View File

@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 128
and hidden_states.shape[0] <= 2048
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual