diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 2f3585ea..a3d1f2c4 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -16,8 +16,7 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fla.ops import RMSNormGated -from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule +from vllm.model_executor.layers.fla.ops import RMSNormGated, chunk from vllm.model_executor.layers.fla.ops.fused_recurrent import \ fused_recurrent_gated_delta_rule from vllm.model_executor.layers.fused_moe import FusedMoE @@ -35,8 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \ mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops import causal_conv1d from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -252,7 +250,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): mixed_qkv_spec = mixed_qkv_spec.view( attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') - mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec = causal_conv1d.causal_conv1d_update( mixed_qkv_spec, conv_state, conv_weights, @@ -269,7 +267,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): if attn_metadata.num_prefills > 0: # - "cache_indices" updates the conv_state cache in positions # pointed to by "mamba_cache_params.state_indices_tensor" - mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn( mixed_qkv_non_spec.transpose(0, 1), conv_weights, self.conv1d.bias, @@ -280,7 +278,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): query_start_loc=non_spec_query_start_loc, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, @@ -364,7 +362,7 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): ( cur_core_attn_out_non_spec, cur_last_recurrent_state, - ) = chunk_gated_delta_rule( + ) = chunk.chunk_gated_delta_rule( query=cur_q, key=cur_k, value=cur_v,