From bb5f16d926ad673c5a67f766bbbaff25b7b4e031 Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Sat, 25 Oct 2025 18:03:36 +0800 Subject: [PATCH] [BugFix] Fix Qwen3-next break (#3428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? Fix Qwen3NextGatedDeltaNet, caused by https://github.com/vllm-project/vllm/pull/26437 ### How was this patch tested? ``` def main(): prompts = [ "窗前明月光,", "The president of the United States is Mr.", "The capital of France is", "The future of AI is", "感时花溅泪,", "家书抵万金啥意思?", "plz tell me a story: ", ] # Create a sampling params object. sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95) # Create an LLM. llm = LLM( model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct", tensor_parallel_size=4, enforce_eager=True, trust_remote_code=True, max_model_len=256, gpu_memory_utilization=0.7, block_size=64 ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/models/qwen3_next.py | 46 ++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 47b6d3ea..f5b4b8a1 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -51,6 +51,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata +from vllm_ascend.utils import vllm_version_is + from vllm.model_executor.models.qwen3_next import ( # isort: skip Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM, Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, @@ -201,7 +203,11 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_masks = attn_metadata.spec_token_masks + if vllm_version_is("0.11.0"): + spec_token_masks = attn_metadata.spec_token_masks + else: + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache[forward_context.virtual_engine] @@ -216,8 +222,9 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): # 1. Set up dimensions for reshapes later projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) - if spec_token_masks is not None: - spec_token_masks = spec_token_masks[:num_actual_tokens] + if vllm_version_is("0.11.0"): + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] projected_states_qkvz, projected_states_ba = torch.split( projected_states, [ @@ -242,8 +249,13 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: - mixed_qkv_spec = mixed_qkv[spec_token_masks] - mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + if vllm_version_is("0.11.0"): + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select( + 0, non_spec_token_indx) else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv @@ -293,10 +305,16 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): g_non_spec = None beta_non_spec = None else: - g_spec = g[:, spec_token_masks] - beta_spec = beta[:, spec_token_masks] - g_non_spec = g[:, ~spec_token_masks] - beta_non_spec = beta[:, ~spec_token_masks] + if vllm_version_is("0.11.0"): + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) else: g_spec = None beta_spec = None @@ -404,8 +422,14 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) - core_attn_out[:, spec_token_masks] = core_attn_out_spec - core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + if vllm_version_is("0.11.0"): + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + else: + core_attn_out.index_copy_(1, spec_token_indx, + core_attn_out_spec) + core_attn_out.index_copy_(1, non_spec_token_indx, + core_attn_out_non_spec) elif spec_sequence_masks is not None: core_attn_out = core_attn_out_spec else: