Fix RMSNorm API call mismatch. (#10032)
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -49,8 +50,11 @@ if _use_aiter:
|
|||||||
from aiter import rmsnorm2d_fwd as rms_norm
|
from aiter import rmsnorm2d_fwd as rms_norm
|
||||||
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
|
import vllm
|
||||||
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
||||||
|
|
||||||
|
_vllm_version = Version(vllm.__version__)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if _is_npu:
|
if _is_npu:
|
||||||
@@ -127,8 +131,21 @@ class RMSNorm(CustomOp):
|
|||||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
# NOTE: Remove this if aiter kernel supports discontinuous input
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
if _vllm_version < Version("0.9"):
|
||||||
return x, residual
|
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||||
|
return x, residual
|
||||||
|
else:
|
||||||
|
residual_out = torch.empty_like(x)
|
||||||
|
output = torch.empty_like(x)
|
||||||
|
fused_add_rms_norm(
|
||||||
|
output,
|
||||||
|
x,
|
||||||
|
residual_out,
|
||||||
|
residual,
|
||||||
|
self.weight.data,
|
||||||
|
self.variance_epsilon,
|
||||||
|
)
|
||||||
|
return output, residual_out
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user