Support trtllm_allreduce_fusion in flashinfer for cuda<12.8 (#9339)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
@@ -292,7 +292,6 @@ class LayerCommunicator:
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
             and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
             and _is_flashinfer_available
         )
 
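Dropping the `_is_sm100_supported` guard is what extends the fusion to CUDA < 12.8: SM100 (Blackwell) parts only exist on CUDA 12.8+ toolkits, so the old gate implicitly excluded every earlier GPU/CUDA combination. A sketch of what such a gate typically checks (illustrative, not the exact sglang helper):

```python
import torch

def is_sm100_supported(device=None) -> bool:
    # Compute capability 10.x (Blackwell) plus a CUDA 12.8+ toolkit;
    # removing this requirement lets the fused path run on older setups.
    major, _ = torch.cuda.get_device_capability(device)
    cuda_version = tuple(int(v) for v in torch.version.cuda.split("."))
    return major >= 10 and cuda_version >= (12, 8)
```
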
@@ -5,7 +5,11 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    is_flashinfer_available,
+    supports_custom_op,
+)
 
 logger = logging.getLogger(__name__)
 
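The widened import pulls in the two helpers that gate custom-op registration. `supports_custom_op` is, in essence, a probe for the `torch.library` custom-op machinery; a minimal sketch of that shape of check (the real helper lives in `sglang.srt.utils`):

```python
import torch

def supports_custom_op() -> bool:
    # torch.library.custom_op / register_fake landed in PyTorch 2.4; on
    # older builds the fused kernel is called as a plain Python function.
    return hasattr(torch.library, "custom_op")
```
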
@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
     return norm_out, residual_out
 
 
+def fake_flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+    return norm_out, residual_out
+
+
+if supports_custom_op():
+    direct_register_custom_op(
+        "flashinfer_allreduce_residual_rmsnorm",
+        flashinfer_allreduce_residual_rmsnorm,
+        mutates_args=["input_tensor", "residual", "weight"],
+        fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+    )
+
+
 def cleanup_flashinfer_workspace():
     global _workspace_manager
     if _workspace_manager is not None:
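The fake implementation never launches the kernel; it only reports output shapes and dtypes so `torch.compile`'s meta tracing can reason about the op. Assuming `direct_register_custom_op` wraps the raw `torch.library` API in the usual way, the registration amounts to something like this (toy op, hypothetical "demo" namespace):

```python
import torch
from torch.library import Library

_lib = Library("demo", "FRAGMENT")  # hypothetical namespace
_lib.define("scale(Tensor x, float factor) -> Tensor")

def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

def scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake impl: allocate a metadata-correct output, do no real work.
    return torch.empty_like(x)

_lib.impl("scale", scale, "CompositeExplicitAutograd")
torch.library.register_fake("demo::scale", scale_fake)

# The op is now callable as torch.ops.demo.scale(x, 2.0).
```
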
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    supports_custom_op,
 )
 
 _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
                 flashinfer_allreduce_residual_rmsnorm,
             )
 
+            fused_op = (
+                torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+                if supports_custom_op()
+                else flashinfer_allreduce_residual_rmsnorm
+            )
+
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_residual_rmsnorm(
+                fused_result = fused_op(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
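Dispatching through `torch.ops.sglang.*` when custom ops are supported keeps the allreduce+RMSNorm fusion as a single opaque node under `torch.compile` (with the fake impl supplying shapes), while the plain-function branch preserves eager behavior on older PyTorch. A hedged call-site sketch with hypothetical tensor names:

```python
# Both branches share the signature shown in the hunk above; this call
# assumes an initialized tensor-parallel process group.
norm_out, residual_out = fused_op(
    input_tensor=hidden_states,
    residual=residual,
    weight=rmsnorm_weight,
)
```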