[Kernel] Qwen3-next: optimize recompute_w_u_fwd & chunk_fwd_o (#74)

Co-authored-by: yuanjizhong <yuanjizhong@baidu.com>
This commit is contained in:
callmelaoyi
2026-01-05 10:24:51 +08:00
committed by GitHub
parent fe666fb24f
commit b86953acf9

View File

@@ -23,8 +23,10 @@ from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
@@ -59,13 +61,17 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
# Kernel implementation
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
w, u = recompute_w_u_fwd(
chunk_indices = prepare_chunk_indices(
cu_seqlens, 64) if cu_seqlens is not None else None
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
@@ -76,7 +82,7 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q,
k=k,
v=v_new,
@@ -84,6 +90,8 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None