[Bugfix] use module-level import for patched function in Qwen3Next (#4354)
### What this PR does / why we need it?
**Problem**: The Qwen3Next model implementation currently imports
`chunk_gated_delta_rule` directly with `from ... import ...`.
In frameworks like `verl`, the model file is often imported before
`vllm-ascend` initializes and applies its patches, so the model
permanently holds a reference to the original (unpatched) vLLM kernel.
This results in execution errors on Ascend devices even if the patch
runs later.
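A minimal, self-contained sketch of the failure mode (the module and function names below are toy stand-ins, not the real vLLM layout):

```python
import sys
import types

# Toy stand-in for the vLLM kernel module that vllm-ascend later patches.
fla_ops = types.ModuleType("fla_ops")
fla_ops.chunk_gated_delta_rule = lambda: "original CUDA kernel"
sys.modules["fla_ops"] = fla_ops

# The model file binds the function object directly -- the reference is
# frozen at import time.
from fla_ops import chunk_gated_delta_rule

# vllm-ascend's patch runs afterwards and replaces the module attribute.
fla_ops.chunk_gated_delta_rule = lambda: "Ascend-patched kernel"

print(chunk_gated_delta_rule())          # still "original CUDA kernel"
print(fla_ops.chunk_gated_delta_rule())  # "Ascend-patched kernel"
```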
**Solution**: Change the import to
`from vllm.model_executor.layers.fla.ops import chunk` and call
`chunk.chunk_gated_delta_rule()`.
This way the function is looked up on the module at call time (late
binding), so the model picks up the patched function regardless of
import order.
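The counterpart sketch, again with a toy module and an illustrative `mixer_forward` helper, mirrors the new import style from the diff below: the module is bound once, the attribute is resolved on every call, and a patch installed later is therefore honored.

```python
import sys
import types

fla_ops = types.ModuleType("fla_ops")
fla_ops.chunk_gated_delta_rule = lambda: "original CUDA kernel"
sys.modules["fla_ops"] = fla_ops

# Module-level import, analogous to
# `from vllm.model_executor.layers.fla.ops import chunk` in the diff below.
import fla_ops as chunk

def mixer_forward():
    # Attribute lookup happens here, at call time, so whichever function is
    # currently installed on the module is used.
    return chunk.chunk_gated_delta_rule()

# Patch applied after the model module was imported.
fla_ops.chunk_gated_delta_rule = lambda: "Ascend-patched kernel"

print(mixer_forward())  # "Ascend-patched kernel"
```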
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
Signed-off-by: zjchenn <zjchenn@gmail.com>
```diff
@@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fla.ops import RMSNormGated
-from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
+from vllm.model_executor.layers.fla.ops import RMSNormGated, chunk
 from vllm.model_executor.layers.fla.ops.fused_recurrent import \
     fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -35,8 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
     mamba_v2_sharded_weight_loader
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
-from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops import causal_conv1d
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -252,7 +250,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
             mixed_qkv_spec = mixed_qkv_spec.view(
                 attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
             mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
-            mixed_qkv_spec = causal_conv1d_update(
+            mixed_qkv_spec = causal_conv1d.causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
                 conv_weights,
@@ -269,7 +267,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "mamba_cache_params.state_indices_tensor"
-            mixed_qkv_non_spec = causal_conv1d_fn(
+            mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn(
                 mixed_qkv_non_spec.transpose(0, 1),
                 conv_weights,
                 self.conv1d.bias,
@@ -280,7 +278,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
                 query_start_loc=non_spec_query_start_loc,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
-            mixed_qkv_non_spec = causal_conv1d_update(
+            mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update(
                 mixed_qkv_non_spec,
                 conv_state,
                 conv_weights,
@@ -364,7 +362,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
             (
                 cur_core_attn_out_non_spec,
                 cur_last_recurrent_state,
-            ) = chunk_gated_delta_rule(
+            ) = chunk.chunk_gated_delta_rule(
                 query=cur_q,
                 key=cur_k,
                 value=cur_v,
```