### What this PR does / why we need it?

Some parameters of the Triton operators are unnecessarily annotated with the `tl.constexpr` modifier. Whenever such a parameter takes a new value, the kernel is recompiled, which significantly degrades model performance. This PR removes the `constexpr` annotation from those parameters.

backport: https://github.com/vllm-project/vllm-ascend/pull/7482

Signed-off-by: w30012745 <wangxiaoshuai2@h-partners.com>
Co-authored-by: w30012745 <wangxiaoshuai2@h-partners.com>
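For context, a minimal sketch of the pitfall (not part of this PR; the toy kernels `scale_constexpr` and `scale_runtime`, the scaling op, and the device string are hypothetical): a parameter annotated `tl.constexpr` is baked into the compiled kernel, so every new value triggers a fresh JIT compilation, whereas a plain runtime argument, optionally listed in `do_not_specialize`, reuses the already-compiled kernel.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_constexpr(X, Y, scale: tl.constexpr, N: tl.constexpr):
    # `scale` is a compile-time constant: every new value of `scale`
    # forces a fresh JIT compilation of this kernel.
    offs = tl.arange(0, N)
    tl.store(Y + offs, tl.load(X + offs) * scale)


@triton.jit(do_not_specialize=["scale"])
def scale_runtime(X, Y, scale, N: tl.constexpr):
    # `scale` is an ordinary runtime argument excluded from specialization,
    # so new values reuse the already-compiled kernel.
    offs = tl.arange(0, N)
    tl.store(Y + offs, tl.load(X + offs) * scale)


device = "npu"  # illustrative; use whatever device your Triton backend targets
x = torch.arange(16, dtype=torch.float32, device=device)
y = torch.empty_like(x)

scale_runtime[(1,)](x, y, 2.0, N=16)    # first call compiles
scale_runtime[(1,)](x, y, 3.0, N=16)    # new value, same compiled kernel
scale_constexpr[(1,)](x, y, 2.0, N=16)  # compiles
scale_constexpr[(1,)](x, y, 3.0, N=16)  # recompiles: the constexpr changed
```

In the file below, only `N` and `MBLOCK` stay `tl.constexpr` (they size the `tl.arange` blocks, which must be compile-time constants), while `eps`, `M`, and `NUM_CHUNKS` are passed as runtime arguments via `do_not_specialize`.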
```python
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fla/ops/l2norm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import torch

from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


@triton.jit(do_not_specialize=["eps", "M", "NUM_CHUNKS"])
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS):
    base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
    rindex = tl.arange(0, N)[None, :]

    for chunk in range(NUM_CHUNKS):
        row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
        xmask = row_idx < M

        xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
        square = xs * xs
        square_sum = tl.sum(square, 1)[:, None]
        rsqrt = tl.rsqrt(square_sum + eps)

        tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
    x_shape_og = x.shape
    x = x.reshape(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    MBLOCK = 69
    # M, N = x.shape
    num_core = get_vectorcore_num()
    main_bs = triton.cdiv(T, num_core)
    num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
    grid = (num_core,)
    l2norm_fwd_kernel2_loop[grid](
        X=x,
        Y=y,
        eps=eps,
        M=T,
        N=D,
        MBLOCK=MBLOCK,
        NUM_CHUNKS=num_sub_blocks,
    )

    return y.view(x_shape_og)
```
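For reference, a minimal usage sketch of `l2norm_fwd` as defined above; the tensor shape, dtypes, and the `"npu"` device string are illustrative assumptions, not taken from this PR.

```python
# Illustrative usage of l2norm_fwd defined above; assumes an Ascend NPU
# environment where vllm_ascend's Triton backend is available.
import torch

x = torch.randn(4, 1024, 128, dtype=torch.float16, device="npu")  # hypothetical shape/device
y = l2norm_fwd(x, eps=1e-6, output_dtype=torch.float32)

# The output keeps the input shape; each length-128 vector along the last
# dimension is scaled by 1 / sqrt(sum(x_i^2) + eps).
assert y.shape == x.shape and y.dtype == torch.float32
```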