[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
chanzhennan
2026-02-28 11:15:50 +08:00
committed by GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions

View File

@@ -9,60 +9,196 @@
# ruff: noqa: E501
import warnings
from typing import Optional
import torch.nn.functional as F
import cocopod # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
A = -A.transpose(1,2)
def torch_solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
# A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=q.dtype)
#kernel版
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
chunk_indices = prepare_chunk_indices(
cu_seqlens, 64) if cu_seqlens is not None else None
def recompute_w_u_fwd_torch(
k: torch.Tensor, # [B, T, H, K]
v: torch.Tensor, # [B, T, H, V]
beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
最简单版本假设等长序列key和value头数相同
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]
batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)
k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
return w, u
def split_by_value(tensor, chunk_size=64):
indices = tensor.tolist()
result = set(indices) # 使用集合避免重复
for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]
# 计算第一个对齐边界
# 我们要找的是 start + n*chunk_size其中n是使结果大于start的最小整数
first_boundary = start + chunk_size
# 在(start, end)范围内插入所有对齐边界
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# !
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)
# !
# A = chunk_scaled_dot_kkt_fwd(k=k,
# beta=beta,
# g_cumsum=g,
# cu_seqlens=cu_seqlens,
# output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)
# torch版
# if get_tensor_model_parallel_rank() == 0:
# torch.save(A, "A_in")
# torch.save(cu_seqlens, "cu_seqlens")
# A2 = A.clone()
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
# !
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
# if get_tensor_model_parallel_rank() == 0:
# err = torch.max(torch.abs(A - A2))
# print("err", err)
# if err > 1e-3:
# raise
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
# for i in range(len(cu_seqlens)-1):
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k,
v=v,
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
"""
w, u = recompute_w_u_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
)
"""
# i
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q,
k=k,
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert (
q.dtype != torch.float32
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2)
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
scale = k.shape[-1] ** -0.5
if False:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()
o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops
kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
else:
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, 'b t h ... -> b h t ...')
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state