From 8d114f254b0ebff665b31bb814cc30cf8aee1c66 Mon Sep 17 00:00:00 2001 From: sogalin <39478626+sogalin@users.noreply.github.com> Date: Sat, 6 Sep 2025 11:45:13 +0800 Subject: [PATCH] Fix RMSNorm API CALL mismatch issue. (#10032) Co-authored-by: Hubert Lu --- python/sglang/srt/layers/layernorm.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index cf8ccf4d1..7743b888e 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn +from packaging.version import Version from sglang.srt.custom_op import CustomOp from sglang.srt.utils import ( @@ -49,8 +50,11 @@ if _use_aiter: from aiter import rmsnorm2d_fwd as rms_norm from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm elif _is_hip: + import vllm from vllm._custom_ops import fused_add_rms_norm, rms_norm + _vllm_version = Version(vllm.__version__) + logger = logging.getLogger(__name__) if _is_npu: @@ -127,8 +131,21 @@ class RMSNorm(CustomOp): # NOTE: Remove this if aiter kernel supports discontinuous input x = x.contiguous() if residual is not None: - fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon) - return x, residual + if _vllm_version < Version("0.9"): + fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + else: + residual_out = torch.empty_like(x) + output = torch.empty_like(x) + fused_add_rms_norm( + output, + x, + residual_out, + residual, + self.weight.data, + self.variance_epsilon, + ) + return output, residual_out out = torch.empty_like(x) rms_norm(out, x, self.weight.data, self.variance_epsilon) return out