[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
chanzhennan
2026-02-28 11:15:50 +08:00
committed by GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions

View File

@@ -11,7 +11,6 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
@@ -28,6 +27,7 @@ RESOLUTION = {
torch.complex64: 1.3e-6,
}
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
assert res.dtype == dtype
ref = ref.to(dtype)
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
rtol = RESOLUTION[dtype]
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
recompute_u_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
recompute_w_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
)
return w, u
return w, u