# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fla/ops/l2norm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import torch
from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


# Row-wise L2 normalization: Y[r, :] = X[r, :] * rsqrt(sum(X[r, :]**2) + eps).
#
# X, Y: pointers to densely packed row-major data with M rows of N elements
#       (row r starts at offset N * r).
# eps:  added to the per-row sum of squares before the reciprocal sqrt.
# Each program handles NUM_CHUNKS * MBLOCK consecutive rows, starting at
# program_id * NUM_CHUNKS * MBLOCK, in NUM_CHUNKS chunks of MBLOCK rows.
# Rows >= M are masked out. Math is done in float32; tl.store converts the
# result to Y's element type.
# NOTE(review): N is used directly as a tl.arange extent with no column mask.
# Upstream Triton requires power-of-two arange extents, so this presumably
# relies on D being a power of two or on the Ascend backend relaxing the
# rule. TODO confirm before using with other head dims.
@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
    base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
    rindex = tl.arange(0, N)[None, :]

    for chunk in range(NUM_CHUNKS):
        row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
        xmask = row_idx < M

        xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
        square = xs * xs
        square_sum = tl.sum(square, 1)[:, None]
        rsqrt = tl.rsqrt(square_sum + eps)

        tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
    """L2-normalize ``x`` along its last dimension.

    Args:
        x: Input tensor. All leading dimensions are flattened into rows.
        eps: Added to each row's sum of squares before the reciprocal
            square root, i.e. ``row * rsqrt(sum(row**2) + eps)``.
        output_dtype: dtype of the result. ``None`` keeps ``x.dtype``.

    Returns:
        A tensor with the same shape as ``x`` holding the normalized rows.
        The normalization is computed in float32 inside the kernel.

    Raises:
        RuntimeError: If a single row of ``x`` exceeds 64KB.
    """
    x_shape_og = x.shape
    # The kernel addresses row r at offset r * D, so rows must be densely
    # packed. reshape() alone may return a strided view (e.g. for a sliced
    # input), so force a contiguous layout. This is a no-op when x is
    # already contiguous.
    x = x.reshape(-1, x.shape[-1]).contiguous()
    # torch.empty_like with dtype=None keeps x's dtype.
    y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    if D > MAX_FUSED_SIZE:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    # Nothing to normalize. Skip compiling/launching a NUM_CHUNKS=0 variant.
    if T == 0:
        return y.view(x_shape_og)

    # Rows per inner-loop chunk.
    # NOTE(review): 69 is an unexplained tuning constant (not a power of two),
    # presumably sized for Ascend on-chip buffers. TODO confirm.
    MBLOCK = 69
    # One program per vector core (as reported by get_vectorcore_num).
    # Each program covers num_sub_blocks * MBLOCK rows, and
    # num_core * num_sub_blocks * MBLOCK >= T, so every row is processed.
    num_core = get_vectorcore_num()
    main_bs = triton.cdiv(T, num_core)
    num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
    grid = (num_core,)
    l2norm_fwd_kernel2_loop[grid](
        X=x,
        Y=y,
        eps=eps,
        M=T,
        N=D,
        MBLOCK=MBLOCK,
        NUM_CHUNKS=num_sub_blocks,
    )
    return y.view(x_shape_og)