Fix illegal memory in trtllm allreduce fusion (#7864)
This commit is contained in:
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
|
||||
if residual is not None:
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.flashinfer_comm_fusion import (
|
||||
flashinfer_allreduce_add_rmsnorm,
|
||||
flashinfer_allreduce_residual_rmsnorm,
|
||||
)
|
||||
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
fused_result = flashinfer_allreduce_add_rmsnorm(
|
||||
fused_result = flashinfer_allreduce_residual_rmsnorm(
|
||||
input_tensor=x,
|
||||
residual=residual,
|
||||
weight=self.weight,
|
||||
|
||||
Reference in New Issue
Block a user