[Kernel] add l2norm triton kernel (#4595)
### What this PR does / why we need it?
This pull request introduces an L2 normalization kernel implemented in
Triton, specifically optimized for Ascend NPUs.
### Does this PR introduce _any_ user-facing change?
No, this PR does not introduce any user-facing changes.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
bc0a5a0c08
---------
Signed-off-by: Ascendyh <hw7osiris@outlook.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
70
vllm_ascend/ops/triton/fla/l2norm.py
Normal file
70
vllm_ascend/ops/triton/fla/l2norm.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fla/ops/l2norm.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
|
||||
MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
|
||||
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
|
||||
rindex = tl.arange(0, N)[None, :]
|
||||
|
||||
for chunk in range(NUM_CHUNKS):
|
||||
row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
|
||||
xmask = row_idx < M
|
||||
|
||||
xs = tl.load(X + (rindex + N * row_idx), mask=xmask,
|
||||
other=0.0).to(tl.float32)
|
||||
square = xs * xs
|
||||
square_sum = tl.sum(square, 1)[:, None]
|
||||
rsqrt = tl.rsqrt(square_sum + eps)
|
||||
|
||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||
|
||||
|
||||
def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: torch.dtype | None = None):
|
||||
x_shape_og = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
# allocate output
|
||||
if output_dtype is None:
|
||||
y = torch.empty_like(x)
|
||||
else:
|
||||
y = torch.empty_like(x, dtype=output_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
T, D = x.shape[0], x.shape[-1]
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||
if D > BD:
|
||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||
|
||||
MBLOCK = 69
|
||||
# M, N = x.shape
|
||||
num_core = get_vectorcore_num()
|
||||
main_bs = triton.cdiv(T, num_core)
|
||||
num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
|
||||
grid = (num_core, )
|
||||
l2norm_fwd_kernel2_loop[grid](
|
||||
X=x,
|
||||
Y=y,
|
||||
eps=eps,
|
||||
M=T,
|
||||
N=D,
|
||||
MBLOCK=MBLOCK,
|
||||
NUM_CHUNKS=num_sub_blocks,
|
||||
)
|
||||
|
||||
return y.view(x_shape_og)
|
||||
Reference in New Issue
Block a user