[Kernel] Optimize the recurrent op
This commit is contained in:
@@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
h0_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
is_h0_transposed=True
|
||||
)
|
||||
return o, final_state
|
||||
|
||||
@@ -150,4 +151,4 @@ def fused_recurrent_gated_delta_rule(
|
||||
num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
return o, final_state
|
||||
|
||||
Reference in New Issue
Block a user