[Kernel] Optimize the recurrent op
This commit is contained in:
@@ -616,6 +616,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
initial_state = ssm_state[
|
initial_state = ssm_state[
|
||||||
non_spec_state_indices_tensor].contiguous()
|
non_spec_state_indices_tensor].contiguous()
|
||||||
initial_state[~has_initial_state, ...] = 0
|
initial_state[~has_initial_state, ...] = 0
|
||||||
|
initial_state = initial_state.transpose(-1, -2).contiguous()
|
||||||
if self.num_v_heads // self.num_k_heads > 1:
|
if self.num_v_heads // self.num_k_heads > 1:
|
||||||
query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||||
key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||||
@@ -634,6 +635,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
cu_seqlens=non_spec_query_start_loc,
|
cu_seqlens=non_spec_query_start_loc,
|
||||||
)
|
)
|
||||||
# Init cache
|
# Init cache
|
||||||
|
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous()
|
||||||
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
|
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
|
||||||
ssm_state.dtype)
|
ssm_state.dtype)
|
||||||
elif attn_metadata.num_decodes > 0:
|
elif attn_metadata.num_decodes > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user