Files
vllm-ascend/tests/e2e/nightly/ops/triton/test_l2norm.py
Ascendyh a90482803d [Kernel] add l2norm triton kernel (#4595)
### What this PR does / why we need it?
This pull request introduces an L2 normalization forward kernel (`l2norm_fwd`)
implemented in Triton and optimized for Ascend NPUs, together with a nightly
e2e test that checks it against PyTorch's reference implementation.
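For context, L2 normalization rescales each vector along the last dimension to unit Euclidean norm, y = x / max(||x||_2, eps). A minimal PyTorch sketch of that computation (the `l2norm_ref` name is illustrative; the `eps` floor mirrors `F.normalize`, whose default eps is 1e-12):

```python
import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Divide each vector by its L2 norm along the last dimension,
    # floored at eps to avoid division by zero.
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x / norm.clamp_min(eps)
```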
### Does this PR introduce _any_ user-facing change?
No, this PR does not introduce any user-facing changes.
### How was this patch tested?
Via the new nightly e2e test `tests/e2e/nightly/ops/triton/test_l2norm.py`
(below), which compares `l2norm_fwd` against `torch.nn.functional.normalize`
across several batch/sequence/head shapes.

- vLLM version: v0.13.0
- vLLM main: bc0a5a0c08

---------

Signed-off-by: Ascendyh <hw7osiris@outlook.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
2025-12-25 06:06:18 +08:00


import pytest
import torch
import torch.nn.functional as F

from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton


@pytest.mark.parametrize(
    ('B', 'T', 'H', 'D', 'dtype'),
    [
        pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test))
        for test in [
            (1, 63, 1, 60, torch.float),
            (2, 500, 4, 64, torch.float),
            (2, 1000, 2, 100, torch.float),
            (3, 1024, 4, 128, torch.float),
        ]
    ],
)
def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype):
    torch.manual_seed(42)
    init_device_properties_triton()
    device = "npu"
    # Tolerances scale with the numeric precision of the dtype under test.
    rtol, atol = (3e-4, 1e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # Shift the inputs off zero mean so the vector norms are well conditioned.
    x = torch.randn(B, T, H, D, dtype=dtype).to(device).requires_grad_(True)
    x = x * 0.5 + 0.3
    # Reference: PyTorch's L2 normalization along the last (head) dimension.
    ref = F.normalize(x, dim=-1, p=2)
    tri = l2norm_fwd(x)
    assert torch.allclose(tri, ref, rtol=rtol, atol=atol)
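
For a standalone sanity check outside pytest, the kernel can also be exercised directly. A sketch, assuming torch_npu is installed so that "npu" tensors are available, reusing one of the parametrized shapes above:

```python
import torch
import torch.nn.functional as F

from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

init_device_properties_triton()
x = torch.randn(2, 500, 4, 64, device="npu")
out = l2norm_fwd(x)
# Compare against PyTorch's reference using the float32 tolerances from the test.
torch.testing.assert_close(out, F.normalize(x, dim=-1, p=2), rtol=3e-4, atol=1e-3)
```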