From a90482803dc12ede67028d4b83e029fde48f1adf Mon Sep 17 00:00:00 2001 From: Ascendyh Date: Thu, 25 Dec 2025 06:06:18 +0800 Subject: [PATCH] [Kernel] add l2norm triton kernel (#4595) ### What this PR does / why we need it? This pull request introduces an L2 normalization kernel implemented in Triton, specifically optimized for Ascend NPUs. ### Does this PR introduce _any_ user-facing change? No, this PR does not introduce any user-facing changes. ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/bc0a5a0c089844b17cb93f3294348f411e523586 --------- Signed-off-by: Ascendyh Co-authored-by: Mengqing Cao --- .github/workflows/_e2e_test.yaml | 1 + tests/e2e/nightly/ops/triton/test_l2norm.py | 34 ++++++++++ vllm_ascend/ops/triton/fla/chunk.py | 2 +- vllm_ascend/ops/triton/fla/l2norm.py | 70 +++++++++++++++++++++ 4 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/nightly/ops/triton/test_l2norm.py create mode 100644 vllm_ascend/ops/triton/fla/l2norm.py diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index ce17c760..f5238c96 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -103,6 +103,7 @@ jobs: # We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run # the test separately. + pytest -sv --durations=0 tests/e2e/nightly/ops/triton/ pytest -sv --durations=0 tests/e2e/singlecard/test_completion_with_prompt_embeds.py pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py diff --git a/tests/e2e/nightly/ops/triton/test_l2norm.py b/tests/e2e/nightly/ops/triton/test_l2norm.py new file mode 100644 index 00000000..0b891468 --- /dev/null +++ b/tests/e2e/nightly/ops/triton/test_l2norm.py @@ -0,0 +1,34 @@ +import pytest +import torch +import torch.nn.functional as F + +from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd +from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (1, 63, 1, 60, torch.float), + (2, 500, 4, 64, torch.float), + (2, 1000, 2, 100, torch.float), + (3, 1024, 4, 128, torch.float), + ] + ], +) +def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype): + torch.manual_seed(42) + init_device_properties_triton() + device = "npu" + rtol, atol = (3e-4, 1e-3) if dtype == torch.float32 else (3e-3, 5e-3) + if dtype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + x = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True) + x = x * 0.5 + 0.3 + + ref = F.normalize(x, dim=-1, p=2) + tri = l2norm_fwd(x) + + assert torch.allclose(tri, ref, rtol=rtol, atol=atol) diff --git a/vllm_ascend/ops/triton/fla/chunk.py b/vllm_ascend/ops/triton/fla/chunk.py index 2d3dade7..03d2d6cd 100644 --- a/vllm_ascend/ops/triton/fla/chunk.py +++ b/vllm_ascend/ops/triton/fla/chunk.py @@ -13,13 +13,13 @@ from typing import Optional import torch from einops import rearrange -from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL from .chunk_delta_h import chunk_gated_delta_rule_fwd_h from .chunk_o import chunk_fwd_o from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd from .solve_tril import solve_tril from .utils import input_guard from .wy_fast import recompute_w_u_fwd diff --git a/vllm_ascend/ops/triton/fla/l2norm.py b/vllm_ascend/ops/triton/fla/l2norm.py new file mode 100644 index 00000000..82c83247 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/l2norm.py @@ -0,0 +1,70 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fla/ops/l2norm.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +from vllm.triton_utils import tl, triton + +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + + +@triton.jit +def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, + MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr): + base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK) + rindex = tl.arange(0, N)[None, :] + + for chunk in range(NUM_CHUNKS): + row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + + xs = tl.load(X + (rindex + N * row_idx), mask=xmask, + other=0.0).to(tl.float32) + square = xs * xs + square_sum = tl.sum(square, 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd(x: torch.Tensor, + eps: float = 1e-6, + output_dtype: torch.dtype | None = None): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + MBLOCK = 69 + # M, N = x.shape + num_core = get_vectorcore_num() + main_bs = triton.cdiv(T, num_core) + num_sub_blocks = triton.cdiv(main_bs, MBLOCK) + grid = (num_core, ) + l2norm_fwd_kernel2_loop[grid]( + X=x, + Y=y, + eps=eps, + M=T, + N=D, + MBLOCK=MBLOCK, + NUM_CHUNKS=num_sub_blocks, + ) + + return y.view(x_shape_og)