diff --git a/vllm_kunlun/ops/fla/chunk.py b/vllm_kunlun/ops/fla/chunk.py index 2da888e..20a039f 100644 --- a/vllm_kunlun/ops/fla/chunk.py +++ b/vllm_kunlun/ops/fla/chunk.py @@ -23,8 +23,10 @@ from .l2norm import l2norm_fwd from .solve_tril import solve_tril from .utils import SUPPRESS_LEVEL, input_guard from .wy_fast import recompute_w_u_fwd - +from .index import prepare_chunk_indices import xspeedgate_ops +import cocopod + def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,): chunk_size=64 @@ -59,13 +61,17 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, #kernel版 torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens) - w, u = recompute_w_u_fwd( + chunk_indices = prepare_chunk_indices( + cu_seqlens, 64) if cu_seqlens is not None else None + w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd( k=k, v=v, beta=beta, A=A, g_cumsum=g, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + chunk_size=64 ) h, v_new, final_state = chunk_gated_delta_rule_fwd_h( k=k, @@ -76,7 +82,7 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) - o = chunk_fwd_o( + o = torch.ops.xspeedgate_ops.chunk_fwd_o( q=q, k=k, v=v_new, @@ -84,6 +90,8 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, g=g, scale=scale, cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + chunk_size=64 ) if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None