diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 0d12476..47b6d3e 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -51,13 +51,10 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm.model_executor.models.qwen3_next import Qwen3NextAttention # isort: skip -from vllm.model_executor.models.qwen3_next import Qwen3NextDecoderLayer # isort: skip -from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM # isort: skip -from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet # isort: skip -from vllm.model_executor.models.qwen3_next import Qwen3NextModel # isort: skip -from vllm.model_executor.models.qwen3_next import Qwen3NextSparseMoeBlock # isort: skip -from vllm.model_executor.models.qwen3_next import fused_gdn_gating # isort: skip +from vllm.model_executor.models.qwen3_next import ( # isort: skip + Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM, + Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, + fused_gdn_gating) class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): @@ -429,17 +426,16 @@ class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): def __init__( self, - config: Qwen3NextConfig, + vllm_config: VllmConfig, layer_type: str, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, prefix: str = "", - enable_eplb: bool = False, ) -> None: nn.Module.__init__(self) - self.config = config + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config self.layer_type = layer_type self.layer_idx = extract_layer_index(prefix) @@ -468,12 +464,8 @@ class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): if (self.layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (self.layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen3NextSparseMoeBlock( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb, - ) + self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3NextMLP( hidden_size=config.hidden_size, @@ -493,14 +485,14 @@ class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): torch.zeros( 1, 1, - self.config.hidden_size, + config.hidden_size, dtype=config.torch_dtype, ), ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, - self.config.hidden_size, + config.hidden_size, dtype=config.torch_dtype, ), ) @@ -511,13 +503,8 @@ class CustomQwen3NextModel(Qwen3NextModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config: Qwen3NextConfig = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config lora_config = vllm_config.lora_config - speculative_config = vllm_config.speculative_config - enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -534,14 +521,9 @@ class CustomQwen3NextModel(Qwen3NextModel): def get_layer(prefix: str): return CustomQwen3NextDecoderLayer( - config, + vllm_config, layer_type=config.layer_types[extract_layer_index(prefix)], - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - speculative_config=speculative_config, prefix=prefix, - enable_eplb=enable_eplb, ) self.start_layer, self.end_layer, self.layers = make_layers(