# File: xc-llm-ascend/tests/e2e/nightly/ops/triton/test_l2norm.py

import pytest
import torch
import torch.nn.functional as F
from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
@pytest.mark.parametrize(
    ("B", "T", "H", "D", "dtype"),
    [
        pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test))
        for test in [
            (1, 63, 1, 60, torch.float),
            (2, 500, 4, 64, torch.float),
            (2, 1000, 2, 100, torch.float),
            (3, 1024, 4, 128, torch.float),
        ]
    ],
)
def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype):
    """Check ``l2norm_fwd`` against ``F.normalize`` along the last dimension.

    A seeded random tensor of shape ``(B, T, H, D)`` is shifted to a non-zero
    mean, moved to the NPU and normalized by the Triton ``l2norm_fwd`` kernel.
    The result is compared with PyTorch's L2 normalization (``p=2``,
    ``dim=-1``) using dtype-dependent tolerances.

    Args:
        B, T, H, D: Sizes of the 4-D input tensor.
        dtype: Element dtype of the generated input.
    """
    torch.manual_seed(42)
    init_device_properties_triton()
    device = "npu"

    # float32 gets the tightest bounds; bfloat16 the loosest; any other dtype
    # (e.g. float16) falls back to the intermediate default.
    tolerances = {
        torch.float32: (3e-4, 1e-3),
        torch.bfloat16: (1e-2, 5e-2),
    }
    rtol, atol = tolerances.get(dtype, (3e-3, 5e-3))

    # Sampled on CPU (no device argument) and then moved, so the seeded values
    # come from the CPU generator. Forward-only test: no autograd needed.
    x = torch.randn(B, T, H, D, dtype=dtype).to(device)
    # Scale and shift so the rows have a non-zero mean (0.3).
    x = x * 0.5 + 0.3

    # Build the reference in float32 so that, for reduced-precision inputs, it
    # is not degraded by its own rounding; cast back to compare like-for-like.
    # For float32 inputs both conversions are no-ops.
    ref = F.normalize(x.float(), dim=-1, p=2).to(dtype)
    tri = l2norm_fwd(x)

    # Same criterion as torch.allclose (|tri - ref| <= atol + rtol * |ref|),
    # but on failure it reports the greatest absolute/relative mismatch.
    torch.testing.assert_close(tri, ref, rtol=rtol, atol=atol)