From 463910e686013acc611a7e3ccb70c3a776c8ffef Mon Sep 17 00:00:00 2001
From: Zhijun Chen
Date: Tue, 25 Nov 2025 20:15:43 +0800
Subject: [PATCH] [Bugfix] use module-level import for patched function in
 Qwen3Next (#4354)

### What this PR does / why we need it?

**Problem**: The Qwen3Next model implementation currently imports
`chunk_gated_delta_rule` directly via `from ... import ...`. In frameworks
such as `verl`, the model file is often imported before `vllm-ascend`
initializes and applies its patches, so the model permanently holds a
reference to the original (unpatched) vLLM kernel. This results in
execution errors on Ascend devices even if the patch runs later.

**Solution**: Change the import style to `from vllm...ops import chunk`
and call `chunk.chunk_gated_delta_rule()`. The function is then looked up
on the module at call time (late binding), so the model picks up the
patched implementation regardless of import order.
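Below is a minimal, self-contained sketch of the binding difference. It
uses a hypothetical stand-in module named `chunk` (registered by hand in
`sys.modules`) rather than the real
`vllm.model_executor.layers.fla.ops.chunk`, but the mechanics are the same:

```python
import sys
import types

# Hypothetical stand-in for vllm.model_executor.layers.fla.ops.chunk,
# fabricated here so the example runs without vLLM installed.
chunk_mod = types.ModuleType("chunk")
chunk_mod.chunk_gated_delta_rule = lambda: "original vLLM kernel"
sys.modules["chunk"] = chunk_mod

# Old style: `from ... import ...` freezes the reference at import time.
from chunk import chunk_gated_delta_rule

# Later, a patcher (vllm-ascend, simulated here) replaces the module
# attribute with an Ascend-compatible implementation.
import chunk
chunk.chunk_gated_delta_rule = lambda: "patched Ascend kernel"

print(chunk_gated_delta_rule())        # "original vLLM kernel"  -- stale reference
print(chunk.chunk_gated_delta_rule())  # "patched Ascend kernel" -- resolved at call time
```

Because the second call resolves the attribute on the module object each
time, it sees whatever function is bound to the module when the call
happens, which is exactly what makes the patch order-independent.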
- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

Signed-off-by: zjchenn
---
 vllm_ascend/models/qwen3_next.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py
index 2f3585ea..a3d1f2c4 100644
--- a/vllm_ascend/models/qwen3_next.py
+++ b/vllm_ascend/models/qwen3_next.py
@@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fla.ops import RMSNormGated
-from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
+from vllm.model_executor.layers.fla.ops import RMSNormGated, chunk
 from vllm.model_executor.layers.fla.ops.fused_recurrent import \
     fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -35,8 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
     mamba_v2_sharded_weight_loader
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
-from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops import causal_conv1d
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -252,7 +250,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
             mixed_qkv_spec = mixed_qkv_spec.view(
                 attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
             mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
-            mixed_qkv_spec = causal_conv1d_update(
+            mixed_qkv_spec = causal_conv1d.causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
                 conv_weights,
@@ -269,7 +267,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "mamba_cache_params.state_indices_tensor"
-            mixed_qkv_non_spec = causal_conv1d_fn(
+            mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn(
                 mixed_qkv_non_spec.transpose(0, 1),
                 conv_weights,
                 self.conv1d.bias,
@@ -280,7 +278,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
                 query_start_loc=non_spec_query_start_loc,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
-            mixed_qkv_non_spec = causal_conv1d_update(
+            mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update(
                 mixed_qkv_non_spec,
                 conv_state,
                 conv_weights,
@@ -364,7 +362,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
             (
                 cur_core_attn_out_non_spec,
                 cur_last_recurrent_state,
-            ) = chunk_gated_delta_rule(
+            ) = chunk.chunk_gated_delta_rule(
                 query=cur_q,
                 key=cur_k,
                 value=cur_v,