[Kernel] add l2norm triton kernel (#4595)
### What this PR does / why we need it?
This pull request introduces an L2 normalization kernel implemented in
Triton, specifically optimized for Ascend NPUs.
### Does this PR introduce _any_ user-facing change?
No, this PR does not introduce any user-facing changes.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
bc0a5a0c08
---------
Signed-off-by: Ascendyh <hw7osiris@outlook.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -103,6 +103,7 @@ jobs:
|
|||||||
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
|
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
|
||||||
# the test separately.
|
# the test separately.
|
||||||
|
|
||||||
|
pytest -sv --durations=0 tests/e2e/nightly/ops/triton/
|
||||||
pytest -sv --durations=0 tests/e2e/singlecard/test_completion_with_prompt_embeds.py
|
pytest -sv --durations=0 tests/e2e/singlecard/test_completion_with_prompt_embeds.py
|
||||||
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py
|
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py
|
||||||
pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py
|
pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py
|
||||||
|
|||||||
34
tests/e2e/nightly/ops/triton/test_l2norm.py
Normal file
34
tests/e2e/nightly/ops/triton/test_l2norm.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
('B', 'T', 'H', 'D', 'dtype'),
|
||||||
|
[
|
||||||
|
pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test))
|
||||||
|
for test in [
|
||||||
|
(1, 63, 1, 60, torch.float),
|
||||||
|
(2, 500, 4, 64, torch.float),
|
||||||
|
(2, 1000, 2, 100, torch.float),
|
||||||
|
(3, 1024, 4, 128, torch.float),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
init_device_properties_triton()
|
||||||
|
device = "npu"
|
||||||
|
rtol, atol = (3e-4, 1e-3) if dtype == torch.float32 else (3e-3, 5e-3)
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
rtol, atol = 1e-2, 5e-2
|
||||||
|
x = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
|
||||||
|
x = x * 0.5 + 0.3
|
||||||
|
|
||||||
|
ref = F.normalize(x, dim=-1, p=2)
|
||||||
|
tri = l2norm_fwd(x)
|
||||||
|
|
||||||
|
assert torch.allclose(tri, ref, rtol=rtol, atol=atol)
|
||||||
@@ -13,13 +13,13 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
|
|
||||||
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL
|
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL
|
||||||
|
|
||||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||||
from .chunk_o import chunk_fwd_o
|
from .chunk_o import chunk_fwd_o
|
||||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||||
from .cumsum import chunk_local_cumsum
|
from .cumsum import chunk_local_cumsum
|
||||||
|
from .l2norm import l2norm_fwd
|
||||||
from .solve_tril import solve_tril
|
from .solve_tril import solve_tril
|
||||||
from .utils import input_guard
|
from .utils import input_guard
|
||||||
from .wy_fast import recompute_w_u_fwd
|
from .wy_fast import recompute_w_u_fwd
|
||||||
|
|||||||
70
vllm_ascend/ops/triton/fla/l2norm.py
Normal file
70
vllm_ascend/ops/triton/fla/l2norm.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fla/ops/l2norm.py
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# This file contains code copied from the flash-linear-attention project.
|
||||||
|
# The original source code was licensed under the MIT license and included
|
||||||
|
# the following copyright notice:
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
|
||||||
|
MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
|
||||||
|
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
|
||||||
|
rindex = tl.arange(0, N)[None, :]
|
||||||
|
|
||||||
|
for chunk in range(NUM_CHUNKS):
|
||||||
|
row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
|
||||||
|
xmask = row_idx < M
|
||||||
|
|
||||||
|
xs = tl.load(X + (rindex + N * row_idx), mask=xmask,
|
||||||
|
other=0.0).to(tl.float32)
|
||||||
|
square = xs * xs
|
||||||
|
square_sum = tl.sum(square, 1)[:, None]
|
||||||
|
rsqrt = tl.rsqrt(square_sum + eps)
|
||||||
|
|
||||||
|
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||||
|
|
||||||
|
|
||||||
|
def l2norm_fwd(x: torch.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
output_dtype: torch.dtype | None = None):
|
||||||
|
x_shape_og = x.shape
|
||||||
|
x = x.reshape(-1, x.shape[-1])
|
||||||
|
# allocate output
|
||||||
|
if output_dtype is None:
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
else:
|
||||||
|
y = torch.empty_like(x, dtype=output_dtype)
|
||||||
|
assert y.stride(-1) == 1
|
||||||
|
T, D = x.shape[0], x.shape[-1]
|
||||||
|
# Less than 64KB per feature: enqueue fused kernel
|
||||||
|
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||||
|
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||||
|
if D > BD:
|
||||||
|
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||||
|
|
||||||
|
MBLOCK = 69
|
||||||
|
# M, N = x.shape
|
||||||
|
num_core = get_vectorcore_num()
|
||||||
|
main_bs = triton.cdiv(T, num_core)
|
||||||
|
num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
|
||||||
|
grid = (num_core, )
|
||||||
|
l2norm_fwd_kernel2_loop[grid](
|
||||||
|
X=x,
|
||||||
|
Y=y,
|
||||||
|
eps=eps,
|
||||||
|
M=T,
|
||||||
|
N=D,
|
||||||
|
MBLOCK=MBLOCK,
|
||||||
|
NUM_CHUNKS=num_sub_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
return y.view(x_shape_og)
|
||||||
Reference in New Issue
Block a user