[Kernel] Optimize the recurrent op

This commit is contained in:
ldh2020
2025-12-21 11:22:06 +08:00
committed by GitHub
parent 004e164bdb
commit b97c781300

View File

@@ -616,6 +616,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
initial_state = ssm_state[ initial_state = ssm_state[
non_spec_state_indices_tensor].contiguous() non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0 initial_state[~has_initial_state, ...] = 0
initial_state = initial_state.transpose(-1, -2).contiguous()
if self.num_v_heads // self.num_k_heads > 1: if self.num_v_heads // self.num_k_heads > 1:
query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
@@ -634,6 +635,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
cu_seqlens=non_spec_query_start_loc, cu_seqlens=non_spec_query_start_loc,
) )
# Init cache # Init cache
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous()
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
ssm_state.dtype) ssm_state.dtype)
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0: