Fix illegal memory access in trtllm allreduce fusion (#7864)

This commit is contained in:
Xiaoyu Zhang
2025-07-09 02:47:17 +08:00
committed by GitHub
parent 51ae40306a
commit 2e7ab862e3
3 changed files with 8 additions and 6 deletions

View File

@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
if residual is not None:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.flashinfer_comm_fusion import (
-flashinfer_allreduce_add_rmsnorm,
+flashinfer_allreduce_residual_rmsnorm,
)
if get_tensor_model_parallel_world_size() > 1:
-fused_result = flashinfer_allreduce_add_rmsnorm(
+fused_result = flashinfer_allreduce_residual_rmsnorm(
input_tensor=x,
residual=residual,
weight=self.weight,