From fce97df9089044811a6092c12cde91f499f30fd1 Mon Sep 17 00:00:00 2001
From: ldh2020 <62470572+ldh2020@users.noreply.github.com>
Date: Tue, 16 Dec 2025 16:24:47 +0800
Subject: [PATCH] [Kernel] Use l2norm kernel op instead of triton op.

---
 vllm_kunlun/ops/fla/l2norm.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/vllm_kunlun/ops/fla/l2norm.py b/vllm_kunlun/ops/fla/l2norm.py
index ef9788c..8f0ed8b 100644
--- a/vllm_kunlun/ops/fla/l2norm.py
+++ b/vllm_kunlun/ops/fla/l2norm.py
@@ -11,9 +11,11 @@ import os
 from typing import Optional
 
 import torch
-
 from vllm.triton_utils import tl, triton
 
+import xtorch_ops
+
+
 BT_LIST = [8, 16, 32, 64, 128]
 
 USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@@ -85,9 +87,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
     tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
 
 
-def l2norm_fwd(x: torch.Tensor,
-               eps: float = 1e-6,
-               output_dtype: Optional[torch.dtype] = None):
+def l2norm_fwd_triton(x: torch.Tensor,
+                      eps: float = 1e-6,
+                      output_dtype: Optional[torch.dtype] = None):
     x_shape_og = x.shape
@@ -141,3 +143,19 @@
     )
     return y.view(x_shape_og)
+
+
+def l2norm_fwd(x: torch.Tensor,
+               eps: float = 1e-6,
+               output_dtype: Optional[torch.dtype] = None):
+    """L2-normalize ``x`` over the last dim using the fused xtorch kernel.
+
+    Drop-in replacement for the Triton path (``l2norm_fwd_triton``): the
+    ``output_dtype`` argument is still honored so existing callers that
+    request a specific dtype keep getting it.
+    """
+    out = torch.empty_like(x)
+    xtorch_ops.l2norm(x, out, eps)
+    if output_dtype is not None and out.dtype != output_dtype:
+        # Kernel writes in x's dtype; cast to match the Triton path's contract.
+        out = out.to(output_dtype)
+    return out