diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index 07ef498..fea85da 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -3,7 +3,7 @@ """Inference-only Qwen3Next model.""" from collections.abc import Iterable from itertools import islice -from typing import Optional +from typing import Optional, Union import torch import torch.nn.functional as F @@ -69,10 +69,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops - - -from typing import Optional, Union from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask +import xtorch_ops @torch.compile(dynamic=True, backend="aot_eager") @@ -613,13 +611,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 3.2: process the remaining part if attn_metadata.num_prefills > 0: - initial_state = ssm_state[ - non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] = 0 + if non_spec_state_indices_tensor.shape[0] > 100: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + else: + initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1: ] + initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device) + for i in range(non_spec_state_indices_tensor.shape[0]): + initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]] + + initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1) initial_state = initial_state.transpose(-1, -2).contiguous() - if self.num_v_heads // self.num_k_heads > 1: - query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) ( core_attn_out_non_spec, last_recurrent_state, @@ -635,9 +637,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): cu_seqlens=non_spec_query_start_loc, ) # Init cache - last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous() - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype) + last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view( + last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1]) + cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1]) + xtorch_ops.reshape_and_cache_flash( + last_recurrent_state, + last_recurrent_state, + cast_ssm_state, + cast_ssm_state, + non_spec_state_indices_tensor) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule(