Refactor fused_add_rmsnorm import logic (#10207)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
is_npu,
|
||||
is_xpu,
|
||||
@@ -33,6 +34,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_hip = is_hip()
|
||||
_is_npu = is_npu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
@@ -41,7 +43,10 @@ _is_cpu = is_cpu()
|
||||
_is_xpu = is_xpu()
|
||||
|
||||
if _is_cuda:
|
||||
from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
|
||||
if _is_flashinfer_available:
|
||||
from flashinfer.norm import fused_add_rmsnorm
|
||||
else:
|
||||
from sgl_kernel import fused_add_rmsnorm
|
||||
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
|
||||
|
||||
if _use_aiter:
|
||||
@@ -84,9 +89,7 @@ class RMSNorm(CustomOp):
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
if residual is not None:
|
||||
flashinfer_fused_add_rmsnorm(
|
||||
x, residual, self.weight.data, self.variance_epsilon
|
||||
)
|
||||
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
return x, residual
|
||||
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user