[Bugfix] use module-level import for patched function in Qwen3Next (#4354)

### What this PR does / why we need it?

**Problem**: The Qwen3Next model implementation currently imports
`chunk_gated_delta_rule` directly using `from ... import ...`.

In frameworks like `verl`, the model file is often imported before
`vllm-ascend` initializes and applies its patches. This causes the model
to permanently hold a reference to the original (unpatched) vLLM kernel,
resulting in execution errors on Ascend devices even if the patch runs
later.

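**Solution**: Changed the import style to `from vllm...ops import chunk`
and call `chunk.chunk_gated_delta_rule()`. The `causal_conv1d` helpers are
handled the same way: the module is imported and the functions are called
as `causal_conv1d.causal_conv1d_fn()` and
`causal_conv1d.causal_conv1d_update()`.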

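This ensures that the function is looked up as an attribute of its module
each time it is called, instead of being bound once when the model file is
imported. The model therefore picks up the patched function regardless of
import order, as long as the patch has been applied before the first call.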

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: zjchenn <zjchenn@gmail.com>
This commit is contained in:
Zhijun Chen
2025-11-25 20:15:43 +08:00
committed by GitHub
parent 941d54a2ce
commit 463910e686

View File

@@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import RMSNormGated
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops import RMSNormGated, chunk
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -35,8 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops import causal_conv1d
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -252,7 +250,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
mixed_qkv_spec = mixed_qkv_spec.view(
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec = causal_conv1d.causal_conv1d_update(
mixed_qkv_spec,
conv_state,
conv_weights,
@@ -269,7 +267,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn(
mixed_qkv_non_spec.transpose(0, 1),
conv_weights,
self.conv1d.bias,
@@ -280,7 +278,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
query_start_loc=non_spec_query_start_loc,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
conv_weights,
@@ -364,7 +362,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
(
cur_core_attn_out_non_spec,
cur_last_recurrent_state,
) = chunk_gated_delta_rule(
) = chunk.chunk_gated_delta_rule(
query=cur_q,
key=cur_k,
value=cur_v,