[Kernel] Use l2norm kernel op instead of Triton op.
@@ -11,9 +11,11 @@ import os
 from typing import Optional
 
 import torch
 
 from vllm.triton_utils import tl, triton
 
+import xtorch_ops
+
 BT_LIST = [8, 16, 32, 64, 128]
 
 USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@@ -85,7 +87,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
     tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
 
 
-def l2norm_fwd(x: torch.Tensor,
+def l2norm_fwd_triton(x: torch.Tensor,
                eps: float = 1e-6,
                output_dtype: Optional[torch.dtype] = None):
     x_shape_og = x.shape
@@ -141,3 +143,11 @@ def l2norm_fwd(x: torch.Tensor,
     )
 
     return y.view(x_shape_og)
+
+
+def l2norm_fwd(x: torch.Tensor,
+               eps: float = 1e-6,
+               output_dtype: Optional[torch.dtype] = None):
+    out = torch.empty_like(x)
+    xtorch_ops.l2norm(x, out, eps)
+    return out
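For sanity-checking the swap, a minimal pure-PyTorch reference can be compared against the new xtorch_ops-backed l2norm_fwd. The sketch below is an assumption-laden illustration, not part of the commit: it assumes the kernel keeps eps inside the square root (matching the xs * rsqrt store in the Triton kernel above), and the l2norm_ref helper, shapes, and tolerances are invented for the example.

import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Row-wise L2 normalization: y = x * rsqrt(sum(x^2) + eps).
    # Assumption: eps lives inside the sqrt, as in the Triton kernel.
    return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)

x = torch.randn(16, 128, device="cuda", dtype=torch.float32)  # assumes a CUDA build
y_ref = l2norm_ref(x)
# y_new = l2norm_fwd(x)  # xtorch_ops-backed path added by this commit;
#                        # it allocates out via torch.empty_like(x), so the
#                        # output_dtype argument is effectively unused here.
# torch.testing.assert_close(y_new, y_ref, rtol=1e-4, atol=1e-4)

The comparison also exercises the one behavioral difference visible in the diff: unlike the renamed l2norm_fwd_triton, the new wrapper returns a tensor with the input's dtype and shape rather than reshaping through x_shape_og.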