[Kernel] Qwen3-next 优化 recompute_w_u_fwd & chunk_fwd_o (#74)
Co-authored-by: yuanjizhong <yuanjizhong@baidu.com>
This commit is contained in:
@@ -23,8 +23,10 @@ from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
import xspeedgate_ops
|
||||
import cocopod
|
||||
|
||||
|
||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||
chunk_size=64
|
||||
@@ -59,13 +61,17 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
|
||||
#kernel版
|
||||
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||
w, u = recompute_w_u_fwd(
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, 64) if cu_seqlens is not None else None
|
||||
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_size=64
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
@@ -76,7 +82,7 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
@@ -84,6 +90,8 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_size=64
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
|
||||
Reference in New Issue
Block a user