[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)

Xiaoyu Zhang
2025-07-03 10:36:20 +08:00
committed by GitHub
parent 82f021e22e
commit 8e64140e35
5 changed files with 253 additions and 2 deletions

@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
        else:
            return self.forward_native(x, residual)

    def forward_with_allreduce_fusion(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward method with allreduce fusion, prioritizing flashinfer fused operations.
        """
        if residual is not None:
            from sglang.srt.distributed import get_tensor_model_parallel_world_size
            from sglang.srt.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_add_rmsnorm,
            )

            if get_tensor_model_parallel_world_size() > 1:
                fused_result = flashinfer_allreduce_add_rmsnorm(
                    input_tensor=x,
                    residual=residual,
                    weight=self.weight,
                    eps=self.variance_epsilon,
                )
                if fused_result[0] is not None:
                    return fused_result

        return self.forward(x, residual)


class GemmaRMSNorm(CustomOp):
    def __init__(
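
Note (not part of the diff): the fused path collapses three separate steps — the tensor-parallel all-reduce, the residual add, and the RMSNorm — into one kernel call. Below is a minimal sketch of the equivalent unfused computation, assuming standard RMSNorm semantics; the function name and the direct use of `torch.distributed.all_reduce` are illustrative and do not appear in this commit.

import torch
import torch.distributed as dist
from typing import Tuple


def allreduce_add_rmsnorm_reference(
    x: torch.Tensor,         # partial output of a row-parallel linear on this rank
    residual: torch.Tensor,  # residual stream to be updated
    weight: torch.Tensor,    # RMSNorm scale, shape (hidden_size,)
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 1) Sum the partial results across tensor-parallel ranks
    #    (requires an initialized process group).
    dist.all_reduce(x)
    # 2) Add the residual; the updated residual is returned alongside the
    #    normalized output, mirroring RMSNorm.forward(x, residual).
    residual = x.float() + residual.float()
    # 3) RMSNorm over the hidden dimension.
    variance = residual.pow(2).mean(dim=-1, keepdim=True)
    out = residual * torch.rsqrt(variance + eps) * weight.float()
    return out.to(x.dtype), residual.to(x.dtype)

If `flashinfer_allreduce_add_rmsnorm` returns `None` in its first slot (presumably when the fused kernel cannot handle the input), or if the tensor-parallel world size is 1, the method falls back to the existing unfused `self.forward(x, residual)` path, so behavior is unchanged outside the fused case.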