@@ -441,6 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
and _is_flashinfer_available
|
and _is_flashinfer_available
|
||||||
and hasattr(layernorm, "forward_with_allreduce_fusion")
|
and hasattr(layernorm, "forward_with_allreduce_fusion")
|
||||||
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
|
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
|
||||||
|
and hidden_states.shape[0] <= 128
|
||||||
):
|
):
|
||||||
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
|
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
|
||||||
hidden_states, residual
|
hidden_states, residual
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ def flashinfer_allreduce_residual_rmsnorm(
|
|||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
max_token_num: int = 128,
|
max_token_num: int = 128,
|
||||||
use_oneshot: Optional[bool] = None,
|
use_oneshot: bool = True,
|
||||||
trigger_completion_at_end: bool = False,
|
trigger_completion_at_end: bool = False,
|
||||||
fp32_acc: bool = False,
|
fp32_acc: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|||||||
Reference in New Issue
Block a user