[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com> Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
@@ -28,6 +27,7 @@ RESOLUTION = {
|
||||
torch.complex64: 1.3e-6,
|
||||
}
|
||||
|
||||
|
||||
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||
assert res.dtype == dtype
|
||||
ref = ref.to(dtype)
|
||||
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||
rtol = RESOLUTION[dtype]
|
||||
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
|
||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
recompute_u_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
recompute_w_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
|
||||
)
|
||||
return w, u
|
||||
return w, u
|
||||
|
||||
Reference in New Issue
Block a user