[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
Author: chanzhennan
Date: 2026-02-28 11:15:50 +08:00
Committed by: GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions


@@ -10,22 +10,21 @@
 import os
 from typing import Optional
 
+import kunlun_ops
 import torch
 from vllm.triton_utils import tl, triton
-import kunlun_ops
 
 BT_LIST = [8, 16, 32, 64, 128]
 USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
 
-@triton.autotune(configs=[
-    triton.Config({}, num_warps=num_warps)
-    for num_warps in [1, 2, 4, 8, 16, 32]
-],
-                 key=['D'])
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
+    ],
+    key=["D"],
+)
 @triton.jit
 def l2norm_fwd_kernel1(
     x,
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
     tl.store(y + cols, b_y, mask=mask)
 
-@triton.autotune(configs=[
-    triton.Config({'BT': BT}, num_warps=num_warps)
-    for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
-],
-                 key=['D'])
+@triton.autotune(
+    configs=[
+        triton.Config({"BT": BT}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16]
+        for BT in BT_LIST
+    ],
+    key=["D"],
+)
 @triton.jit(do_not_specialize=["NB"])
 def l2norm_fwd_kernel(
     x,
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
     tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
 
-def l2norm_fwd_triton(x: torch.Tensor,
-                      eps: float = 1e-6,
-                      output_dtype: Optional[torch.dtype] = None):
-    x_shape_og = x.shape
-    x = x.view(-1, x.shape[-1])
-    # allocate output
-    if output_dtype is None:
-        y = torch.empty_like(x)
-    else:
-        y = torch.empty_like(x, dtype=output_dtype)
-    assert y.stride(-1) == 1
-    T, D = x.shape[0], x.shape[-1]
-    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
-    # Less than 64KB per feature: enqueue fused kernel
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
-    if D > BD:
-        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
-    if not USE_DEFAULT_FLA_NORM:
-        MBLOCK = 32
-        # M, N = x.shape
-        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
-            x,
-            y,
-            eps,
-            T,
-            D,
-            MBLOCK,
-        )
-    else:
-        if D <= 512:
-            NB = triton.cdiv(T, 2048)
-
-            def grid(meta):
-                return (triton.cdiv(T, meta['BT']), )
-
-            l2norm_fwd_kernel[grid](
-                x,
-                y,
-                eps,
-                NB=NB,
-                T=T,
-                D=D,
-                BD=BD,
-            )
-        else:
-            l2norm_fwd_kernel1[(T, )](
-                x,
-                y,
-                eps=eps,
-                D=D,
-                BD=BD,
-            )
-    return y.view(x_shape_og)
-
-def l2norm_fwd(x: torch.Tensor,
-               eps: float = 1e-6,
-               output_dtype: Optional[torch.dtype] = None):
+def l2norm_fwd(
+    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
+):
     out = torch.empty_like(x)
     kunlun_ops.l2norm(x, out, eps)
     return out
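
The new l2norm_fwd delegates entirely to the fused kunlun_ops.l2norm kernel, dropping the Triton dispatch path removed above. As a reviewer aid, here is a minimal pure-PyTorch sketch of the intended math, assuming the Kunlun op keeps the removed path's semantics (L2 normalization along the last dimension, with eps added inside the reciprocal square root). l2norm_ref is a hypothetical helper for checking, not part of this commit:

import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize along the last dimension: y = x * rsqrt(sum(x^2) + eps).
    # Assumption: eps sits inside the square root, matching the removed
    # Triton kernels' `xs * rsqrt` pattern.
    var = (x * x).sum(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps)

# Hypothetical sanity check (the real op needs Kunlun hardware):
# x = torch.randn(4, 128)
# torch.testing.assert_close(l2norm_fwd(x), l2norm_ref(x), rtol=1e-3, atol=1e-3)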