[Kernel] Optimize the recurrent op

This commit is contained in:
ldh2020
2025-12-21 11:18:00 +08:00
committed by GitHub
parent 58c1db5073
commit 004e164bdb

View File

@@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True
)
return o, final_state
@@ -150,4 +151,4 @@ def fused_recurrent_gated_delta_rule(
num_accepted_tokens,
use_qk_l2norm_in_kernel,
)
return o, final_state
return o, final_state