[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com> Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -9,60 +9,196 @@
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
import torch.nn.functional as F
|
||||
|
||||
import cocopod # noqa
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
from .index import prepare_chunk_indices
|
||||
import xspeedgate_ops
|
||||
import cocopod
|
||||
|
||||
|
||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||
chunk_size=64
|
||||
A = -A.transpose(1,2)
|
||||
def torch_solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
):
|
||||
chunk_size = 64
|
||||
A = -A.transpose(1, 2)
|
||||
sequence_length = A.shape[-2]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
A = F.pad(A, (0, 0, 0, pad_size))
|
||||
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
|
||||
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
|
||||
|
||||
# A = A.masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = A[..., i, :i].clone()
|
||||
sub = A[..., :i, :i].clone()
|
||||
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
|
||||
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
|
||||
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
|
||||
:, :, :sequence_length, :
|
||||
].transpose(1, 2)
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=q.dtype)
|
||||
|
||||
#kernel版
|
||||
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, 64) if cu_seqlens is not None else None
|
||||
def recompute_w_u_fwd_torch(
|
||||
k: torch.Tensor, # [B, T, H, K]
|
||||
v: torch.Tensor, # [B, T, H, V]
|
||||
beta: torch.Tensor, # [B, T, H]
|
||||
g: torch.Tensor, # [B, T, H]
|
||||
A: torch.Tensor, # [B, H, T, T]
|
||||
):
|
||||
"""
|
||||
最简单版本:假设等长序列,key和value头数相同
|
||||
"""
|
||||
chunk_size = 64
|
||||
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
|
||||
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
||||
k, v, beta, g, A = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
|
||||
]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = k.shape
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
k = F.pad(k, (0, 0, 0, pad_size))
|
||||
v = F.pad(v, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
A = F.pad(A, (0, 0, 0, pad_size))
|
||||
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
|
||||
|
||||
v_beta = v * beta.unsqueeze(-1)
|
||||
k_beta = k * beta.unsqueeze(-1)
|
||||
|
||||
k, v, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (k, v, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
|
||||
u = A @ v_beta
|
||||
w = A @ (k_beta * g.exp().unsqueeze(-1))
|
||||
w = (
|
||||
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
u = (
|
||||
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
return w, u
|
||||
|
||||
|
||||
def split_by_value(tensor, chunk_size=64):
|
||||
indices = tensor.tolist()
|
||||
result = set(indices) # 使用集合避免重复
|
||||
|
||||
for i in range(len(indices) - 1):
|
||||
start = indices[i]
|
||||
end = indices[i + 1]
|
||||
|
||||
# 计算第一个对齐边界
|
||||
# 我们要找的是 start + n*chunk_size,其中n是使结果大于start的最小整数
|
||||
first_boundary = start + chunk_size
|
||||
|
||||
# 在(start, end)范围内插入所有对齐边界
|
||||
boundary = first_boundary
|
||||
while boundary < end:
|
||||
result.add(boundary)
|
||||
boundary += chunk_size
|
||||
|
||||
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
chunk_size = 64
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
|
||||
)
|
||||
chunk_offsets = (
|
||||
prepare_chunk_offsets(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# !
|
||||
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
|
||||
g,
|
||||
chunk_size=64,
|
||||
reverse=False,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
head_first=False,
|
||||
)
|
||||
|
||||
# !
|
||||
# A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
# beta=beta,
|
||||
# g_cumsum=g,
|
||||
# cu_seqlens=cu_seqlens,
|
||||
# output_dtype=q.dtype)
|
||||
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
|
||||
k, beta, g, cu_seqlens, chunk_indices, chunk_size
|
||||
)
|
||||
|
||||
# torch版
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# torch.save(A, "A_in")
|
||||
# torch.save(cu_seqlens, "cu_seqlens")
|
||||
# A2 = A.clone()
|
||||
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
|
||||
|
||||
# !
|
||||
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# err = torch.max(torch.abs(A - A2))
|
||||
# print("err", err)
|
||||
# if err > 1e-3:
|
||||
# raise
|
||||
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
# for i in range(len(cu_seqlens)-1):
|
||||
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
||||
|
||||
"""
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
for i in range(len(cu_seqlens)-1):
|
||||
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
|
||||
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
|
||||
|
||||
w_i, u_i = recompute_w_u_fwd_torch(
|
||||
k=k_i,
|
||||
v=v_i,
|
||||
beta=beta_i,
|
||||
A=A_i,
|
||||
g=g_i,
|
||||
)
|
||||
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
|
||||
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
|
||||
"""
|
||||
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_size=64
|
||||
chunk_size=64,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
"""
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
"""
|
||||
|
||||
# i
|
||||
# import os
|
||||
# if not os.path.exists("/qwen-next/in"):
|
||||
# os.makedirs("/qwen-next/in")
|
||||
# torch.save(k, "/qwen-next/in/k.pt")
|
||||
# torch.save(u, "/qwen-next/in/u.pt")
|
||||
# torch.save(w, "/qwen-next/in/w.pt")
|
||||
# torch.save(g, "/qwen-next/in/g.pt")
|
||||
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
|
||||
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
|
||||
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
|
||||
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
|
||||
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
|
||||
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
|
||||
|
||||
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
|
||||
k,
|
||||
u,
|
||||
w,
|
||||
g,
|
||||
initial_state,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
chunk_offsets.to(torch.int32),
|
||||
chunk_size,
|
||||
output_final_state,
|
||||
True,
|
||||
)
|
||||
|
||||
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
# k=k,
|
||||
# w=w,
|
||||
# u=u,
|
||||
# g=g,
|
||||
# initial_state=initial_state,
|
||||
# output_final_state=output_final_state,
|
||||
# cu_seqlens=cu_seqlens,
|
||||
# )
|
||||
# if not os.path.exists("/qwen-next/out"):
|
||||
# os.makedirs("/qwen-next/out")
|
||||
# torch.save(h, "/qwen-next/out/h.pt")
|
||||
# torch.save(v_new, "/qwen-next/out/v_new.pt")
|
||||
# torch.save(final_state, "/qwen-next/out/final_state.pt")
|
||||
|
||||
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_size=64
|
||||
chunk_size=64,
|
||||
)
|
||||
"""
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
"""
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
@torch.amp.custom_fwd(device_type="cuda")
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
def chunk_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
assert (
|
||||
q.dtype != torch.float32
|
||||
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert (
|
||||
len(beta.shape) == 3
|
||||
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
|
||||
)
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
if False:
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
g = g.contiguous()
|
||||
beta = beta.contiguous()
|
||||
initial_state = initial_state.contiguous()
|
||||
|
||||
o = torch.empty_like(v)
|
||||
final_state = torch.empty_like(initial_state)
|
||||
import kunlun_ops
|
||||
|
||||
kunlun_ops.gated_delta_rule(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
initial_state,
|
||||
g,
|
||||
beta,
|
||||
final_state,
|
||||
o,
|
||||
scale,
|
||||
cu_seqlens.cpu(),
|
||||
cu_seqlens,
|
||||
cu_seqlens.cpu(),
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
else:
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
o = rearrange(o, "b t h ... -> b h t ...")
|
||||
return o, final_state
|
||||
|
||||
Reference in New Issue
Block a user