[Kernel] Optimize the recurrent op
This commit is contained in:
@@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
|||||||
h0_indices=ssm_state_indices,
|
h0_indices=ssm_state_indices,
|
||||||
num_accepted_tokens=num_accepted_tokens,
|
num_accepted_tokens=num_accepted_tokens,
|
||||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||||
|
is_h0_transposed=True
|
||||||
)
|
)
|
||||||
return o, final_state
|
return o, final_state
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user