diff --git a/vllm_kunlun/ops/fla/fused_recurrent.py b/vllm_kunlun/ops/fla/fused_recurrent.py index 143b6a0..3902bee 100644 --- a/vllm_kunlun/ops/fla/fused_recurrent.py +++ b/vllm_kunlun/ops/fla/fused_recurrent.py @@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function): h0_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + is_h0_transposed=True ) return o, final_state @@ -150,4 +151,4 @@ def fused_recurrent_gated_delta_rule( num_accepted_tokens, use_qk_l2norm_in_kernel, ) - return o, final_state \ No newline at end of file + return o, final_state