[Kernel] Optimize the selection and update ops for the SSM state

This commit is contained in:
ldh2020
2025-12-21 15:45:32 +08:00
committed by GitHub
parent b97c781300
commit 8261a09e2a

View File

@@ -3,7 +3,7 @@
"""Inference-only Qwen3Next model.""" """Inference-only Qwen3Next model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice from itertools import islice
from typing import Optional from typing import Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -69,10 +69,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer,
maybe_prefix) maybe_prefix)
from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
from typing import Optional, Union
from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
import xtorch_ops
@torch.compile(dynamic=True, backend="aot_eager") @torch.compile(dynamic=True, backend="aot_eager")
@@ -613,13 +611,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 3.2: process the remaining part # 3.2: process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
initial_state = ssm_state[ if non_spec_state_indices_tensor.shape[0] > 100:
non_spec_state_indices_tensor].contiguous() initial_state = ssm_state[
initial_state[~has_initial_state, ...] = 0 non_spec_state_indices_tensor].contiguous()
else:
initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1: ]
initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device)
for i in range(non_spec_state_indices_tensor.shape[0]):
initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]]
initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1)
initial_state = initial_state.transpose(-1, -2).contiguous() initial_state = initial_state.transpose(-1, -2).contiguous()
if self.num_v_heads // self.num_k_heads > 1:
query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
( (
core_attn_out_non_spec, core_attn_out_non_spec,
last_recurrent_state, last_recurrent_state,
@@ -635,9 +637,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
cu_seqlens=non_spec_query_start_loc, cu_seqlens=non_spec_query_start_loc,
) )
# Init cache # Init cache
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous() last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
ssm_state.dtype) cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
xtorch_ops.reshape_and_cache_flash(
last_recurrent_state,
last_recurrent_state,
cast_ssm_state,
cast_ssm_state,
non_spec_state_indices_tensor)
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = ( core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule( fused_recurrent_gated_delta_rule(