[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen3-Next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
@@ -10,22 +10,21 @@
 import os
 from typing import Optional
 
+import kunlun_ops
 import torch
 from vllm.triton_utils import tl, triton
 
-import kunlun_ops
-
-
 BT_LIST = [8, 16, 32, 64, 128]
 
 USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
 
 
-@triton.autotune(configs=[
-    triton.Config({}, num_warps=num_warps)
-    for num_warps in [1, 2, 4, 8, 16, 32]
-],
-                 key=['D'])
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
+    ],
+    key=["D"],
+)
 @triton.jit
 def l2norm_fwd_kernel1(
     x,
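The decorator rewrite above is formatting-only, but the pattern it touches is load-bearing: `@triton.autotune` benchmarks every `triton.Config` in `configs` on first launch and caches the winner per distinct value of the arguments named in `key`, so keying on `D` re-tunes once per feature size. A minimal runnable sketch of the same pattern (the kernel name, sizes, and `cuda` device are illustrative, not from this diff):

    import torch
    from vllm.triton_utils import tl, triton

    @triton.autotune(
        configs=[triton.Config({}, num_warps=w) for w in [1, 2, 4, 8]],
        key=["D"],  # re-tune whenever a new value of D is seen
    )
    @triton.jit
    def _scale_kernel(x, y, D: tl.constexpr, BD: tl.constexpr):
        # One program handles one row of BD lanes; BD must be a power of two.
        cols = tl.arange(0, BD)
        mask = cols < D
        v = tl.load(x + cols, mask=mask, other=0.0)
        tl.store(y + cols, v * 2.0, mask=mask)

    x = torch.randn(512, device="cuda")
    y = torch.empty_like(x)
    _scale_kernel[(1,)](x, y, D=512, BD=512)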
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
     tl.store(y + cols, b_y, mask=mask)
 
 
-@triton.autotune(configs=[
-    triton.Config({'BT': BT}, num_warps=num_warps)
-    for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
-],
-                 key=['D'])
+@triton.autotune(
+    configs=[
+        triton.Config({"BT": BT}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16]
+        for BT in BT_LIST
+    ],
+    key=["D"],
+)
 @triton.jit(do_not_specialize=["NB"])
 def l2norm_fwd_kernel(
     x,
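One substantive detail survives the reformat: `do_not_specialize=["NB"]` on the second kernel. By default Triton folds properties of integer arguments (equality to 1, divisibility by 16) into its specialization key and compiles a fresh variant when they change; listing `NB` opts that argument out, so a value that varies with the token count does not trigger recompiles on every call. A hedged sketch of the mechanism (hypothetical kernel, not from this diff):

    from vllm.triton_utils import tl, triton

    @triton.jit(do_not_specialize=["NB"])
    def _sum_loop_kernel(x, NB, D: tl.constexpr):
        # NB stays a plain runtime integer: Triton does not bake its value
        # (or divisibility) into the binary, so varying NB across launches
        # reuses one compilation instead of producing new ones.
        cols = tl.arange(0, D)
        acc = tl.zeros([D], dtype=tl.float32)
        for _ in range(NB):
            acc += tl.load(x + cols).to(tl.float32)
        tl.store(x + cols, acc)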
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
     tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
 
 
-def l2norm_fwd_triton(x: torch.Tensor,
-                      eps: float = 1e-6,
-                      output_dtype: Optional[torch.dtype] = None):
-    x_shape_og = x.shape
-    x = x.view(-1, x.shape[-1])
-    # allocate output
-    if output_dtype is None:
-        y = torch.empty_like(x)
-    else:
-        y = torch.empty_like(x, dtype=output_dtype)
-    assert y.stride(-1) == 1
-    T, D = x.shape[0], x.shape[-1]
-    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
-    # Less than 64KB per feature: enqueue fused kernel
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
-    if D > BD:
-        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
-
-    if not USE_DEFAULT_FLA_NORM:
-        MBLOCK = 32
-        # M, N = x.shape
-        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
-            x,
-            y,
-            eps,
-            T,
-            D,
-            MBLOCK,
-        )
-    else:
-        if D <= 512:
-            NB = triton.cdiv(T, 2048)
-
-            def grid(meta):
-                return (triton.cdiv(T, meta['BT']), )
-
-            l2norm_fwd_kernel[grid](
-                x,
-                y,
-                eps,
-                NB=NB,
-                T=T,
-                D=D,
-                BD=BD,
-            )
-        else:
-            l2norm_fwd_kernel1[(T, )](
-                x,
-                y,
-                eps=eps,
-                D=D,
-                BD=BD,
-            )
-
-    return y.view(x_shape_og)
-
-
-def l2norm_fwd(x: torch.Tensor,
-               eps: float = 1e-6,
-               output_dtype: Optional[torch.dtype] = None):
+def l2norm_fwd(
+    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
+):
     out = torch.empty_like(x)
     kunlun_ops.l2norm(x, out, eps)
     return out
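This hunk deletes the Python-side dispatcher `l2norm_fwd_triton` (which selected among the Triton kernels based on `USE_DEFAULT_FLA_NORM` and the feature size `D`) and routes `l2norm_fwd` straight to the fused `kunlun_ops.l2norm`. For validating that swap, here is a plain-PyTorch reference of what the removed kernels compute, assuming the usual FLA convention `y = x * rsqrt(sum(x^2) + eps)` over the last dimension, consistent with the `xs * rsqrt` store in `l2norm_fwd_kernel2`:

    import torch

    def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Row-wise L2 normalization over the last dim, eps inside the sqrt.
        return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)

    # Hypothetical check against the fused op; tolerances are arbitrary and
    # device setup depends on the Kunlun runtime:
    # out = torch.empty_like(x); kunlun_ops.l2norm(x, out, eps)
    # torch.testing.assert_close(out, l2norm_ref(x), rtol=1e-3, atol=1e-3)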