import pytest
import torch
import torch.nn.functional as F

from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

|
@pytest.mark.parametrize(
    ('B', 'T', 'H', 'D', 'dtype'),
    [
        pytest.param(b, t, h, d, dt, id=f"B{b}-T{t}-H{h}-D{d}-{dt}")
        for (b, t, h, d, dt) in (
            (1, 63, 1, 60, torch.float),
            (2, 500, 4, 64, torch.float),
            (2, 1000, 2, 100, torch.float),
            (3, 1024, 4, 128, torch.float),
        )
    ],
)
def test_l2norm(B: int, T: int, H: int, D: int, dtype: torch.dtype):
    """Check the Triton l2norm_fwd kernel against torch.nn.functional.normalize.

    Builds a (B, T, H, D) random tensor on the NPU and asserts that the
    kernel's output matches the torch reference within a dtype-dependent
    tolerance.
    """
    torch.manual_seed(42)
    init_device_properties_triton()

    # Tolerances scale with precision. Only float32 cases are parametrized
    # today, but looser bounds are kept ready for half/bfloat16 inputs.
    if dtype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif dtype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3

    # Shift/scale away from a zero-centered distribution before normalizing.
    x = torch.randn(B, T, H, D, dtype=dtype).to("npu").requires_grad_(True)
    x = x * 0.5 + 0.3

    # Reference result first, then the kernel under test.
    ref = F.normalize(x, dim=-1, p=2)
    tri = l2norm_fwd(x)

    assert torch.allclose(tri, ref, rtol=rtol, atol=atol)
|