diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index cf51706..21ce47b 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -53,4 +53,4 @@ def register_model():
     )
     ModelRegistry.register_model(
         "Qwen3NextForCausalLM",
-        "vllm_ascend.models.qwen3_next:Qwen3NextForCausalLM")
+        "vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py
index 175a529..0d12476 100644
--- a/vllm_ascend/models/qwen3_next.py
+++ b/vllm_ascend/models/qwen3_next.py
@@ -14,10 +14,9 @@ from vllm.attention import AttentionBackend, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
-from vllm.distributed import (divide, get_pp_group,
-                              get_tensor_model_parallel_rank,
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
-from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fla.ops import RMSNormGated
 from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
 from vllm.model_executor.layers.fla.ops.fused_recurrent import \
@@ -44,27 +43,24 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
-                                                   MixtureOfExperts,
-                                                   SupportsLoRA, SupportsPP)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
-from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
-                                                   Qwen3NextSparseMoeBlock,
-                                                   fused_gdn_gating)
 from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, PPMissingLayer, extract_layer_index,
-    is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
-    make_layers, maybe_prefix)
+    PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
-from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Qwen3NextConfig
-from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
+from vllm.model_executor.models.qwen3_next import Qwen3NextAttention  # isort: skip
+from vllm.model_executor.models.qwen3_next import Qwen3NextDecoderLayer  # isort: skip
+from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM  # isort: skip
+from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet  # isort: skip
+from vllm.model_executor.models.qwen3_next import Qwen3NextModel  # isort: skip
+from vllm.model_executor.models.qwen3_next import Qwen3NextSparseMoeBlock  # isort: skip
+from vllm.model_executor.models.qwen3_next import fused_gdn_gating  # isort: skip
 
-class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
+
+class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
 
     @property
     def mamba_type(self) -> str:
@@ -80,14 +76,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
 
     def get_state_shape(self) -> tuple[tuple[int, ...],
                                        tuple[int, ...]]:
         return MambaStateShapeCalculator.gated_delta_net_state_shape(
-            self.tp_size,
-            self.num_k_heads,
-            self.num_v_heads,
-            self.head_k_dim,
-            self.head_v_dim,
-            self.conv_kernel_size,
-            self.num_spec,
-            use_v1=True)
+            self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
+            self.head_v_dim, self.conv_kernel_size, self.num_spec)
 
     def __init__(
         self,
@@ -98,7 +88,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         speculative_config: Optional[SpeculativeConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__()
+        nn.Module.__init__(self)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.hidden_size = config.hidden_size
@@ -195,85 +185,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-    def fix_query_key_value_ordering(
-        self,
-        mixed_qkvz,
-        mixed_ba,
-    ):
-        """
-        Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
-        """
-        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
-            self.num_k_heads // self.tp_size,
-            (self.head_k_dim + self.head_k_dim +
-             (self.head_v_dim + self.head_v_dim) * self.num_v_heads //
-             self.num_k_heads),
-        )
-        new_tensor_shape_ba = mixed_qkvz.size()[:-1] + (
-            self.num_k_heads // self.tp_size,
-            2 * self.num_v_heads // self.num_k_heads,
-        )
-
-        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
-        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
-
-        split_arg_list_qkvz = [
-            self.head_k_dim,
-            self.head_k_dim,
-            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
-            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
-        ]
-        split_arg_list_ba = [
-            self.num_v_heads // self.num_k_heads,
-            self.num_v_heads // self.num_k_heads
-        ]
-
-        # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
-        # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn],
-        # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
-        (query, key, value, z) = torch.split(mixed_qkvz,
-                                             split_arg_list_qkvz,
-                                             dim=2)
-        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
-
-        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
-        value = value.reshape(value.size(0), -1, self.head_v_dim)
-        z = z.reshape(z.size(0), -1, self.head_v_dim)
-        b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
-        a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)
-
-        return query, key, value, z, b, a
-
-    def rearrange_mixed_qkv(self, mixed_qkv):
-        if mixed_qkv is None:
-            return None, None, None
-        query, key, value = torch.split(
-            mixed_qkv,
-            [
-                self.key_dim // self.tp_size,
-                self.key_dim // self.tp_size,
-                self.value_dim // self.tp_size,
-            ],
-            dim=-1,
-        )
-        query, key = map(
-            lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim),
-            (query, key))
-        value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim)
-        return query, key, value
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
-        cache_params: Optional[MambaCacheParams] = None,
-    ):
-        return torch.ops.vllm.npu_gdn_attention(
-            hidden_states,
-            output,
-            self.prefix,
-        )
-
     def _forward(
         self,
         hidden_states: torch.Tensor,
@@ -340,24 +251,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv
 
-        # 2.1: process the mutli-query part
-        # if spec_sequence_masks is not None:
-        #     mixed_qkv_spec = mixed_qkv_spec.view(
-        #         attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
-        #     mixed_qkv_spec = rearrange(mixed_qkv_spec,
-        #                                'b l d -> b d l')
-        #     mixed_qkv_spec = causal_conv1d_update(
-        #         mixed_qkv_spec,
-        #         conv_state,
-        #         conv_weights,
-        #         self.conv1d.bias,
-        #         self.activation,
-        #         conv_state_indices=spec_state_indices_tensor[:, 0]
-        #         [:attn_metadata.num_spec_decodes],
-        #         num_accepted_tokens=num_accepted_tokens,
-        #         validate_data=False,
-        #     )
-        #     mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
-
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
@@ -532,7 +425,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
 
 
-class Qwen3NextDecoderLayer(nn.Module):
+class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
 
     def __init__(
         self,
@@ -545,14 +438,14 @@ class Qwen3NextDecoderLayer(nn.Module):
         prefix: str = "",
         enable_eplb: bool = False,
     ) -> None:
-        super().__init__()
+        nn.Module.__init__(self)
         self.config = config
         self.layer_type = layer_type
         self.layer_idx = extract_layer_index(prefix)
 
         if self.layer_type == "linear_attention":
-            self.linear_attn = Qwen3NextGatedDeltaNet(
+            self.linear_attn = CustomQwen3NextGatedDeltaNet(
                 config,
                 model_config=model_config,
                 cache_config=cache_config,
@@ -611,69 +504,12 @@ class Qwen3NextDecoderLayer(nn.Module):
                     dtype=config.torch_dtype,
                 ),
             )
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-        positions: torch.Tensor = None,
-        **kwargs: object,
-    ):
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
-
-        self_attention_output = torch.empty_like(hidden_states)
-        if self.layer_type == "linear_attention":
-            self.linear_attn(
-                hidden_states=hidden_states,
-                output=self_attention_output,
-            )
-        elif self.layer_type == "full_attention":
-            self.self_attn(
-                hidden_states=hidden_states,
-                output=self_attention_output,
-                positions=positions,
-            )
-        else:
-            raise ValueError("Invalid layer_type")
-        hidden_states = self_attention_output
-
-        if self.layer_scale:
-            if len(hidden_states.shape) == 2:
-                hidden_states = hidden_states * (
-                    self.attn_layer_scale.to(hidden_states.dtype)[0] + 1)
-            else:
-                hidden_states = hidden_states * (
-                    self.attn_layer_scale.to(hidden_states.dtype) + 1)
-
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
-
-        if self.layer_scale:
-            if len(hidden_states.shape) == 2:
-                hidden_states = hidden_states * (
-                    self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1)
-            else:
-                assert len(hidden_states.shape) == len(
-                    self.ffn_layer_scale.shape
-                ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}'  # noqa: E501
-                hidden_states = hidden_states * (
-                    self.ffn_layer_scale.to(hidden_states.dtype) + 1)
-
-        return hidden_states, residual
-
 
 @support_torch_compile
-class Qwen3NextModel(nn.Module):
+class CustomQwen3NextModel(Qwen3NextModel):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-
+        nn.Module.__init__(self)
         config: Qwen3NextConfig = vllm_config.model_config.hf_config
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
@@ -697,7 +533,7 @@ class Qwen3NextModel(nn.Module):
         )
 
         def get_layer(prefix: str):
-            return Qwen3NextDecoderLayer(
+            return CustomQwen3NextDecoderLayer(
                 config,
                 layer_type=config.layer_types[extract_layer_index(prefix)],
                 model_config=model_config,
@@ -717,52 +553,6 @@ class Qwen3NextModel(nn.Module):
         self.norm = Qwen3NextRMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-            residual = None
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
-
-        for layer in self.layers:
-            hidden_states, residual = layer(
-                positions=positions,
-                hidden_states=hidden_states,
-                residual=residual,
-            )
-
-        if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-        hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
-
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts,
-            num_redundant_experts=self.num_redundant_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -842,10 +632,10 @@ class Qwen3NextModel(nn.Module):
         return loaded_params
 
 
-class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                           MixtureOfExperts, IsHybrid):
+class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -856,12 +646,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "Qwen3Next currently does not support prefix caching"
         assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
         self.quant_config = vllm_config.quant_config
-
-        super().__init__()
         self.config = config
         self.scheduler_config = scheduler_config
-        self.model = Qwen3NextModel(vllm_config=vllm_config,
-                                    prefix=maybe_prefix(prefix, "model"))
+        self.model = CustomQwen3NextModel(vllm_config=vllm_config,
+                                          prefix=maybe_prefix(prefix, "model"))
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -904,127 +692,3 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.num_local_physical_experts = example_layer.n_local_physical_experts
         self.num_routed_experts = example_layer.n_routed_experts
         self.num_redundant_experts = example_layer.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
-
-    def update_physical_experts_metadata(
-        self,
-        num_physical_experts: int,
-        num_local_physical_experts: int,
-    ) -> None:
-        assert self.num_local_physical_experts == num_local_physical_experts
-        self.num_physical_experts = num_physical_experts
-        self.num_local_physical_experts = num_local_physical_experts
-        self.num_redundant_experts = (num_physical_experts -
-                                      self.num_logical_experts)
-        for layer in self.model.layers:
-            if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
-                moe = layer.mlp
-                moe.n_local_physical_experts = num_local_physical_experts
-                moe.n_physical_experts = num_physical_experts
-                moe.n_redundant_experts = self.num_redundant_experts
-                moe.experts.update_expert_map()
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs: object,
-    ):
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-
-        return hidden_states
-
-    @classmethod
-    def get_mamba_state_dtype_from_config(
-        cls,
-        vllm_config: "VllmConfig",
-    ) -> tuple[torch.dtype, torch.dtype]:
-        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
-            vllm_config.model_config.dtype,
-            vllm_config.cache_config.mamba_cache_dtype)
-
-    @classmethod
-    def get_mamba_state_shape_from_config(
-        cls, vllm_config: "VllmConfig"
-    ) -> tuple[tuple[int, int], tuple[int, int]]:
-        parallel_config = vllm_config.parallel_config
-        hf_config = vllm_config.model_config.hf_config
-        tp_size = parallel_config.tensor_parallel_size
-        num_spec = (vllm_config.speculative_config.num_speculative_tokens
-                    if vllm_config.speculative_config else 0)
-        return MambaStateShapeCalculator.gated_delta_net_state_shape(
-            tp_size,
-            hf_config.linear_num_key_heads,
-            hf_config.linear_num_value_heads,
-            hf_config.linear_key_head_dim,
-            hf_config.linear_value_head_dim,
-            hf_config.linear_conv_kernel_dim,
-            num_spec,
-            use_v1=True)
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata=None,  # type: ignore
-    ) -> Optional[torch.Tensor]:
-        return self.logits_processor(self.lm_head, hidden_states,
-                                     sampling_metadata)
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=["mtp."],
-        )
-        return loader.load_weights(weights)
-
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
-
-
-def npu_gdn_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
-    layer_name: str,
-) -> None:
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
-
-
-def npu_gdn_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
-    layer_name: str,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="npu_gdn_attention",
-    op_func=npu_gdn_attention,
-    mutates_args=["output"],
-    fake_impl=npu_gdn_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
-)
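The recurring pattern in this patch: each Custom* class subclasses the upstream
vLLM implementation but calls nn.Module.__init__(self) rather than
super().__init__(), so the parent constructor (and any backend-specific setup
it performs) is skipped entirely while every other inherited method remains
available. A minimal sketch of that pattern, using hypothetical Base/Custom
names that are not part of the patch:

    import torch.nn as nn

    class Base(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            # Backend-specific setup the subclass wants to avoid.
            self.proj = nn.Linear(4096, 4096)

        def helper(self) -> str:
            return "inherited behavior"

    class Custom(Base):

        def __init__(self) -> None:
            # Bypass Base.__init__ but still run nn.Module.__init__ so that
            # parameter/submodule registration works on this instance.
            nn.Module.__init__(self)
            self.proj = nn.Linear(8, 8)  # platform-appropriate replacement

    m = Custom()
    assert m.helper() == "inherited behavior"  # methods still come from Base

The trade-off is that the subclass must rebuild every attribute the inherited
methods rely on, which is why the Custom* __init__ bodies above largely mirror
the upstream constructors.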