Revert "feat: update flashinfer ar oneshot params (#8687)" (#9054)

This commit is contained in:
Yineng Zhang
2025-08-10 23:24:42 -07:00
committed by GitHub
parent b32792516a
commit 9d834fdcc1
2 changed files with 2 additions and 1 deletion

View File

@@ -441,6 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+and hidden_states.shape[0] <= 128
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual

View File

@@ -125,7 +125,7 @@ def flashinfer_allreduce_residual_rmsnorm(
weight: torch.Tensor,
eps: float = 1e-6,
max_token_num: int = 128,
-use_oneshot: Optional[bool] = None,
+use_oneshot: bool = True,
trigger_completion_at_end: bool = False,
fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: