From b97c78130042f6012e70d242fcaf79a8f12ce79e Mon Sep 17 00:00:00 2001 From: ldh2020 <62470572+ldh2020@users.noreply.github.com> Date: Sun, 21 Dec 2025 11:22:06 +0800 Subject: [PATCH] [Kernel] Optimize the recurrent op --- vllm_kunlun/models/qwen3_next.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index d8c0aac..07ef498 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -616,6 +616,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): initial_state = ssm_state[ non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 + initial_state = initial_state.transpose(-1, -2).contiguous() if self.num_v_heads // self.num_k_heads > 1: query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) @@ -634,6 +635,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): cu_seqlens=non_spec_query_start_loc, ) # Init cache + last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous() ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( ssm_state.dtype) elif attn_metadata.num_decodes > 0: