Fix slow fused add RMSNorm (#10141)
This commit is contained in:
@@ -39,12 +39,8 @@ _is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
fused_add_rmsnorm,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
rmsnorm,
|
||||
)
|
||||
from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
|
||||
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
|
||||
|
||||
if _use_aiter:
|
||||
from aiter import rmsnorm2d_fwd as rms_norm
|
||||
@@ -86,7 +82,9 @@ class RMSNorm(CustomOp):
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
flashinfer_fused_add_rmsnorm(
|
||||
x, residual, self.weight.data, self.variance_epsilon
|
||||
)
|
||||
return x, residual
|
||||
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user