Support trtllm_allreduce_fusion in flashinfer for CUDA < 12.8 (#9339)

Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
strgrb
2025-08-21 07:54:30 +08:00
committed by GitHub
parent 8f5b9910c1
commit 88fbc31b50
3 changed files with 37 additions and 3 deletions

View File

@@ -27,6 +27,7 @@ from sglang.srt.utils import (
is_cuda,
is_hip,
is_npu,
supports_custom_op,
)
_is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
flashinfer_allreduce_residual_rmsnorm,
)
fused_op = (
torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
if supports_custom_op()
else flashinfer_allreduce_residual_rmsnorm
)
if get_tensor_model_parallel_world_size() > 1:
fused_result = flashinfer_allreduce_residual_rmsnorm(
fused_result = fused_op(
input_tensor=x,
residual=residual,
weight=self.weight,