diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 47b6d3ea..f5b4b8a1 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -51,6 +51,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata +from vllm_ascend.utils import vllm_version_is + from vllm.model_executor.models.qwen3_next import ( # isort: skip Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM, Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, @@ -201,7 +203,11 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_masks = attn_metadata.spec_token_masks + if vllm_version_is("0.11.0"): + spec_token_masks = attn_metadata.spec_token_masks + else: + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache[forward_context.virtual_engine] @@ -216,8 +222,9 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): # 1. Set up dimensions for reshapes later projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) - if spec_token_masks is not None: - spec_token_masks = spec_token_masks[:num_actual_tokens] + if vllm_version_is("0.11.0"): + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] projected_states_qkvz, projected_states_ba = torch.split( projected_states, [ @@ -242,8 +249,13 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: - mixed_qkv_spec = mixed_qkv[spec_token_masks] - mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + if vllm_version_is("0.11.0"): + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select( + 0, non_spec_token_indx) else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv @@ -293,10 +305,16 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): g_non_spec = None beta_non_spec = None else: - g_spec = g[:, spec_token_masks] - beta_spec = beta[:, spec_token_masks] - g_non_spec = g[:, ~spec_token_masks] - beta_non_spec = beta[:, ~spec_token_masks] + if vllm_version_is("0.11.0"): + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) else: g_spec = None beta_spec = None @@ -404,8 +422,14 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) - core_attn_out[:, spec_token_masks] = core_attn_out_spec - core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + if vllm_version_is("0.11.0"): + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + else: + core_attn_out.index_copy_(1, spec_token_indx, + core_attn_out_spec) + core_attn_out.index_copy_(1, non_spec_token_indx, + core_attn_out_non_spec) elif spec_sequence_masks is not None: core_attn_out = core_attn_out_spec else: