Support trtllm_allreduce_fusion in flashinfer for cuda<12.8 (#9339)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
supports_custom_op,
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
|
||||
flashinfer_allreduce_residual_rmsnorm,
|
||||
)
|
||||
|
||||
fused_op = (
|
||||
torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
|
||||
if supports_custom_op()
|
||||
else flashinfer_allreduce_residual_rmsnorm
|
||||
)
|
||||
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
fused_result = flashinfer_allreduce_residual_rmsnorm(
|
||||
fused_result = fused_op(
|
||||
input_tensor=x,
|
||||
residual=residual,
|
||||
weight=self.weight,
|
||||
|
||||
Reference in New Issue
Block a user