diff --git a/vllm_kunlun/ops/fla/l2norm.py b/vllm_kunlun/ops/fla/l2norm.py
index ef9788c..8f0ed8b 100644
--- a/vllm_kunlun/ops/fla/l2norm.py
+++ b/vllm_kunlun/ops/fla/l2norm.py
@@ -11,9 +11,11 @@ import os
 from typing import Optional
 
 import torch
-
 from vllm.triton_utils import tl, triton
+import xtorch_ops
+
+
 
 BT_LIST = [8, 16, 32, 64, 128]
 
 USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@@ -85,7 +87,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
     tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
 
 
-def l2norm_fwd(x: torch.Tensor,
+def l2norm_fwd_triton(x: torch.Tensor,
                eps: float = 1e-6,
                output_dtype: Optional[torch.dtype] = None):
     x_shape_og = x.shape
@@ -141,3 +143,17 @@ def l2norm_fwd(x: torch.Tensor,
     )
     return y.view(x_shape_og)
 
+
+
+def l2norm_fwd(x: torch.Tensor,
+               eps: float = 1e-6,
+               output_dtype: Optional[torch.dtype] = None):
+    """L2-normalize ``x`` via the fused xtorch_ops kernel (fast path that
+    replaces the Triton implementation, now kept as l2norm_fwd_triton)."""
+    out = torch.empty_like(x)
+    xtorch_ops.l2norm(x, out, eps)
+    # Honor output_dtype like l2norm_fwd_triton does. Cast after the op:
+    # the fused kernel presumably writes in x.dtype — TODO confirm whether
+    # xtorch_ops.l2norm accepts a mismatched-out dtype directly.
+    if output_dtype is not None and output_dtype != x.dtype:
+        out = out.to(output_dtype)
+    return out