From 729b7edf726c827cf2706b684cbde83b08fbc189 Mon Sep 17 00:00:00 2001
From: "Huaiyu, Zheng"
Date: Thu, 16 Oct 2025 08:54:18 +0800
Subject: [PATCH] enable rmsnorm on XPU (#10248)

---
 python/sglang/bench_one_batch.py      | 21 ++++++-----
 python/sglang/srt/layers/layernorm.py | 50 ++++++++++++++++++++++-----
 2 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 213ef2715..d52da3a6e 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
+    is_cuda_alike,
+    is_xpu,
     kill_process_tree,
     require_mlp_sync,
     require_mlp_tp_gather,
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 
+profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
+    profiler_activity
+    for available, profiler_activity in [
+        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
+        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
+    ]
+    if available
+]
+
 
 @dataclasses.dataclass
 class BenchArgs:
@@ -424,10 +435,7 @@ def latency_test_run_once(
     profiler = None
     if profile:
         profiler = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
+            activities=profile_activities,
             with_stack=True,
             record_shapes=profile_record_shapes,
         )
@@ -460,10 +468,7 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler = None
             profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
+                activities=profile_activities,
                 with_stack=True,
                 record_shapes=profile_record_shapes,
             )
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 399ef3e71..a0b75780b 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -52,8 +52,13 @@ if _is_cuda:
         gemma_rmsnorm,
         rmsnorm,
     )
-
-
+elif _is_xpu:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_with_allreduce_fusion(
         self,
         x: torch.Tensor,
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
         if _is_hip:
             self._forward_method = self.forward_native
 
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            gemma_fused_add_rmsnorm(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-            return x, residual
-        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-        return out
+        return self._forward_impl(x, residual)
 
     def forward_npu(
         self,
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        return self._forward_impl(x, residual)
+
 
 class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
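Note (not part of the patch): below is a minimal pure-PyTorch sketch of what the new XPU path computes, assuming the sgl_kernel ops keep the usual semantics, i.e. rmsnorm() returns a new tensor while fused_add_rmsnorm() first adds x into residual and then writes the normalized result back into x in place, which is why forward_xpu returns (x, residual). The *_ref names are hypothetical reference helpers, not the actual XPU kernels.

import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2, dim=-1) + eps) * weight, accumulated in fp32
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    y = x32 * torch.rsqrt(variance + eps)
    return (y * weight.float()).to(x.dtype)

def fused_add_rmsnorm_ref(x, residual, weight, eps):
    # residual <- x + residual; x <- rmsnorm(residual); both updated in place
    residual.add_(x)
    x.copy_(rmsnorm_ref(residual, weight, eps))
    return x, residual

GemmaRMSNorm._forward_impl follows the same pattern with the Gemma-style kernels, which scale by (1 + weight) instead of weight.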