[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)
@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_add_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_add_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
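As a quick orientation to what the fused path computes, below is a minimal single-process sketch of the equivalent unfused math: a residual add followed by RMSNorm over the hidden dimension. The tensor-parallel allreduce that the fused kernel also performs is omitted, since a single process has nothing to reduce across. The helper name reference_add_rmsnorm is hypothetical and not part of this commit; it only assumes the standard RMSNorm formulation with self.variance_epsilon passed as eps.

import torch

def reference_add_rmsnorm(x, residual, weight, eps):
    # Hypothetical unfused reference: add the residual in the input dtype,
    # then apply RMSNorm in float32 and cast back -- the RMSNorm(x + residual)
    # pattern that the fused allreduce + add + rmsnorm kernel replaces.
    orig_dtype = x.dtype
    residual = x + residual                       # updated residual (second return value)
    h = residual.to(torch.float32)
    variance = h.pow(2).mean(dim=-1, keepdim=True)
    h = h * torch.rsqrt(variance + eps)           # normalize over the hidden dimension
    out = (h * weight.to(torch.float32)).to(orig_dtype)
    return out, residual

# Example usage: 4 tokens, hidden size 8
x = torch.randn(4, 8, dtype=torch.float16)
res = torch.randn(4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
out, new_res = reference_add_rmsnorm(x, res, w, eps=1e-6)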