# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import os
from typing import Optional

import kunlun_ops
import torch
from vllm.triton_utils import tl, triton

BT_LIST = [8, 16, 32, 64, 128]

USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    # Each program normalizes one row of length D.
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute the squared L2 norm of the row
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


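# A minimal host-side launch sketch for l2norm_fwd_kernel1 (illustrative helper,
# an assumption, not a dispatch path taken by this file: one program per row,
# BD must be a power of two that covers the feature dimension D).
def _l2norm_fwd_rowwise(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x2d = x.reshape(-1, x.shape[-1]).contiguous()
    T, D = x2d.shape
    # The block covers the whole row.
    BD = triton.next_power_of_2(D)
    y = torch.empty_like(x2d)
    l2norm_fwd_kernel1[(T,)](x2d, y, D, BD=BD, eps=eps)
    return y.reshape(x.shape)

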
@triton.autotune(
    configs=[
        triton.Config({"BT": BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    # Each program normalizes a (BT, BD) tile of rows using block pointers.
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


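# A similar illustrative sketch for the tiled kernel (an assumption: BT is
# picked by the autotuner, so the grid must be a function of the selected
# config; NB is not read by the kernel body above, so a placeholder is passed).
def _l2norm_fwd_blocked(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x2d = x.reshape(-1, x.shape[-1]).contiguous()
    T, D = x2d.shape
    BD = triton.next_power_of_2(D)
    y = torch.empty_like(x2d)

    def grid(meta):
        # One program per block of BT rows.
        return (triton.cdiv(T, meta["BT"]),)

    l2norm_fwd_kernel[grid](x2d, y, eps, NB=0, T=T, D=D, BD=BD)
    return y.reshape(x.shape)

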
@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
    # Each program normalizes MBLOCK rows at a time; N is the full row width.
    xoffset = tl.program_id(0) * MBLOCK
    row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
    xmask = row_idx < M
    rindex = tl.arange(0, N)[None, :]
    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
    square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
    rsqrt = tl.rsqrt(square_sum + eps)
    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


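# An illustrative launch sketch for l2norm_fwd_kernel2 (assumptions: both N and
# MBLOCK must be powers of two because tl.arange requires it, and MBLOCK is a
# caller-chosen number of rows per program).
def _l2norm_fwd_rowblock(
    x: torch.Tensor, eps: float = 1e-6, mblock: int = 32
) -> torch.Tensor:
    x2d = x.reshape(-1, x.shape[-1]).contiguous()
    M, N = x2d.shape
    y = torch.empty_like(x2d)
    l2norm_fwd_kernel2[(triton.cdiv(M, mblock),)](x2d, y, eps, M, N=N, MBLOCK=mblock)
    return y.reshape(x.shape)

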
def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    # Row-wise L2 normalization through the Kunlun (XPU) custom op.
    out = torch.empty_like(x)
    kunlun_ops.l2norm(x, out, eps)
    # Honor the requested output dtype, if any (out is allocated with x's dtype).
    if output_dtype is not None and out.dtype != output_dtype:
        out = out.to(output_dtype)
    return out
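# Usage sketch (assumptions: the kunlun_ops extension and a Kunlun XPU device
# are available; the device string and tolerances below are illustrative only):
#
#   x = torch.randn(4, 128, device="xpu", dtype=torch.float16)
#   y = l2norm_fwd(x)  # y[i] = x[i] / sqrt(sum(x[i] ** 2) + eps), row-wise
#   ref = torch.nn.functional.normalize(x.float(), p=2, dim=-1)
#   torch.testing.assert_close(y.float(), ref, atol=1e-3, rtol=1e-3)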