# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Qwen3Next model.""" from collections.abc import Iterable from itertools import islice from typing import Optional import kunlun_ops import torch import torch.nn.functional as F from einops import rearrange from torch import nn from transformers.activations import ACT2FN from vllm.attention import AttentionBackend, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config, ) from vllm.distributed import ( divide, get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE # yapf conflicts with isort for this block # yapf: disable from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm # yapf: enable from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader, ) from vllm.model_executor.models.interfaces import ( HasInnerState, IsHybrid, MixtureOfExperts, SupportsLoRA, SupportsPP, ) from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, sequence_parallel_chunk, ) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops.attention.layer import Attention from vllm_kunlun.ops.fla import ( RMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, ) from vllm_kunlun.ops.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from vllm_kunlun.v1.attention.backends.gdn_attn import GDNAttentionMetadata @torch.compile(dynamic=True, backend="aot_eager") def get_masked_input_and_mask_kunlun( input_: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, num_org_vocab_padding: int, added_vocab_start_index: int, added_vocab_end_index: int, ) -> tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) added_vocab_mask = (input_ >= added_vocab_start_index) & ( input_ < added_vocab_end_index ) added_offset = ( added_vocab_start_index - (org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding ) valid_offset = (org_vocab_start_index * org_vocab_mask) + ( added_offset * added_vocab_mask ) vocab_mask = org_vocab_mask | added_vocab_mask input_ = vocab_mask * (input_ - valid_offset) return input_, ~vocab_mask get_masked_input_and_mask = get_masked_input_and_mask_kunlun logger = init_logger(__name__) KVCache = tuple[torch.Tensor, torch.Tensor] class Qwen3NextMLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": raise ValueError( f"Unsupported activation: {hidden_act}. " "Only silu is supported for now." ) self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config quant_config = vllm_config.quant_config self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group self.ep_rank = self.ep_group.rank() self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}." ) # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = self.ep_rank * self.n_local_physical_experts self.physical_expert_end = ( self.physical_expert_start + self.n_local_physical_experts ) self.experts = FusedMoE( num_experts=self.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, ) self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate", ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=self.experts.must_reduce_shared_expert_outputs(), prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) shared_output = None if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: shared_output = ( F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states, 0 ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states ) ) return final_hidden_states.view(orig_shape) class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): @property def mamba_type(self) -> str: return "linear_attention" def get_attn_backend(self) -> type["AttentionBackend"]: from vllm_kunlun.v1.attention.backends.gdn_attn import GDNAttentionBackend return GDNAttentionBackend def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( self.model_config.dtype, self.cache_config.mamba_cache_dtype ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, self.head_v_dim, self.conv_kernel_size, self.num_spec, ) def __init__( self, config: Qwen3NextConfig, model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, prefix: str = "", ) -> None: super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.hidden_size = config.hidden_size self.num_v_heads = config.linear_num_value_heads self.num_k_heads = config.linear_num_key_heads self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_idx = extract_layer_index(prefix) self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.rms_norm_eps self.prefix = prefix self.config = config self.model_config = model_config self.cache_config = cache_config self.quant_config = quant_config self.speculative_config = speculative_config self.num_spec = ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0 ) # QKV self.conv_dim = self.key_dim * 2 + self.value_dim self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.conv_dim, bias=False, prefix=f"{prefix}.conv1d", ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) # projection of the input hidden states self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 self.projection_size_ba = self.num_v_heads * 2 self.in_proj_qkvz = ColumnParallelLinear( input_size=self.hidden_size, output_size=self.projection_size_qkvz, bias=False, quant_config=quant_config, prefix=f"{prefix}.in_proj_qkvz", ) # ba_proj doesn't support blockwise fp8 quantization. self.in_proj_ba = ColumnParallelLinear( input_size=self.hidden_size, output_size=self.projection_size_ba, bias=False, quant_config=quant_config, prefix=f"{prefix}.in_proj_ba", ) query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( self.conv1d.weight, { "weight_loader": mamba_v2_sharded_weight_loader( [ query_key_settings, query_key_settings, value_settings, ], self.tp_size, self.tp_rank, ) }, ) # selective projection used to make dt, B and C input dependant # time step projection (discretization) # instantiate once and copy inv_dt in init_weights of PretrainedModel self.dt_bias = nn.Parameter( torch.ones(self.num_v_heads // self.tp_size), ) self.A_log = nn.Parameter( torch.empty( divide(self.num_v_heads, self.tp_size), dtype=torch.float32, ) ) set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, group_size=None, norm_before_gate=True, device=current_platform.current_device(), dtype=torch.get_default_dtype(), # config.torch_dtype, ) self.out_proj = RowParallelLinear( self.value_dim, self.hidden_size, bias=False, input_is_parallel=True, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self def fix_query_key_value_ordering( self, mixed_qkvz, mixed_ba, ): """ Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, ( self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // self.num_k_heads ), ) new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, 2 * self.num_v_heads // self.num_k_heads, ) mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) mixed_ba = mixed_ba.view(*new_tensor_shape_ba) split_arg_list_qkvz = [ self.head_k_dim, self.head_k_dim, (self.num_v_heads // self.num_k_heads * self.head_v_dim), (self.num_v_heads // self.num_k_heads * self.head_v_dim), ] split_arg_list_ba = [ self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads, ] # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) b, a = torch.split(mixed_ba, split_arg_list_ba, dim=2) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] value = value.reshape(value.size(0), -1, self.head_v_dim) z = z.reshape(z.size(0), -1, self.head_v_dim) b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) return query, key, value, z, b, a def rearrange_mixed_qkv(self, mixed_qkv): if mixed_qkv is None: return None, None, None query, key, value = torch.split( mixed_qkv, [ self.key_dim // self.tp_size, self.key_dim // self.tp_size, self.value_dim // self.tp_size, ], dim=-1, ) query, key = map( lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key), ) value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query, key, value def forward( self, hidden_states: torch.Tensor, output: torch.Tensor, ): return torch.ops.vllm.gdn_attention( hidden_states, output, self.prefix, ) def _forward( self, hidden_states: torch.Tensor, output: torch.Tensor, ): forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: # V1 profile run return assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, GDNAttentionMetadata) has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks spec_token_masks = attn_metadata.spec_token_masks spec_state_indices_tensor = ( attn_metadata.spec_state_indices_tensor ) # noqa: E501 non_spec_state_indices_tensor = ( attn_metadata.non_spec_state_indices_tensor ) # noqa: E501 non_spec_state_indices_tensor_cpu = ( attn_metadata.non_spec_state_indices_tensor_cpu ) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0] ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens if spec_token_masks is not None: spec_token_masks = spec_token_masks[:num_actual_tokens] # 1. Set up dimensions for reshapes later projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens]) query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) query, key, value = map( lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) ) mixed_qkv = torch.cat((query, key, value), dim=-1) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: mixed_qkv_spec = mixed_qkv[spec_token_masks] mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv # 2.1: process the mutli-query part if spec_sequence_masks is not None: mixed_qkv_spec = causal_conv1d_update( mixed_qkv_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, conv_state_indices=spec_state_indices_tensor[:, 0][ : attn_metadata.num_spec_decodes ], num_accepted_tokens=num_accepted_tokens, query_start_loc=spec_query_start_loc, max_query_len=spec_state_indices_tensor.size(-1), validate_data=False, ) # 2.2: process the remaining part if attn_metadata.num_prefills > 0: # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec, conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ) elif attn_metadata.num_decodes > 0: mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ : attn_metadata.num_decodes ], conv_state_indices_cpu=non_spec_state_indices_tensor_cpu[ : attn_metadata.num_decodes ], validate_data=True, ) else: mixed_qkv_non_spec = None query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( mixed_qkv_non_spec ) beta = b.sigmoid() g = ops.fused_gdn_gating(self.A_log.float(), a, self.dt_bias.float()) g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: g_spec = g beta_spec = beta g_non_spec = None beta_non_spec = None else: g_spec = g[:, spec_token_masks] beta_spec = beta[:, spec_token_masks] g_non_spec = g[:, ~spec_token_masks] beta_non_spec = beta[:, ~spec_token_masks] else: g_spec = None beta_spec = None g_non_spec = g beta_non_spec = beta # 3. Recurrent attention # 3.1: process the mutlti-query part if spec_sequence_masks is not None: core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( q=query_spec, k=key_spec, v=value_spec, g=g_spec, beta=beta_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], ssm_state_indices=spec_state_indices_tensor, num_accepted_tokens=num_accepted_tokens, use_qk_l2norm_in_kernel=True, ) else: core_attn_out_spec, last_recurrent_state = None, None # 3.2: process the remaining part if attn_metadata.num_prefills > 0: slot_mapping = torch.full( (ssm_state.shape[0],), -1, dtype=torch.int32, device="cuda" ) slot_mapping[non_spec_state_indices_tensor] = torch.arange( len(non_spec_state_indices_tensor), dtype=torch.int32, device="cuda" ) initial_state_shape = ( non_spec_state_indices_tensor.shape + ssm_state.shape[1:] ) initial_state = torch.empty( initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device ) initial_state = initial_state.view( initial_state.shape[0], 1, -1, initial_state.shape[-1] ) cast_ssm_state = ssm_state.view(ssm_state.shape[0], -1, ssm_state.shape[-1]) kunlun_ops.reshape_and_cache_flash( cast_ssm_state, cast_ssm_state, initial_state, initial_state, slot_mapping, ) initial_state = initial_state.view(initial_state_shape) initial_state = initial_state * has_initial_state.view( has_initial_state.shape[0], 1, 1, 1 ) initial_state = initial_state.transpose(-1, -2).contiguous() ( core_attn_out_non_spec, last_recurrent_state, ) = chunk_gated_delta_rule( q=query_non_spec, k=key_non_spec, v=value_non_spec, g=g_non_spec, beta=beta_non_spec, initial_state=initial_state, output_final_state=True, use_qk_l2norm_in_kernel=True, cu_seqlens=non_spec_query_start_loc, ) # Init cache last_recurrent_state = ( last_recurrent_state.transpose(-1, -2) .contiguous() .to(ssm_state.dtype) .view(last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1]) ) cast_ssm_state = ssm_state.view( ssm_state.shape[0], 1, -1, ssm_state.shape[-1] ) kunlun_ops.reshape_and_cache_flash( last_recurrent_state, last_recurrent_state, cast_ssm_state, cast_ssm_state, non_spec_state_indices_tensor, ) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule( q=query_non_spec, k=key_non_spec, v=value_non_spec, g=g_non_spec, beta=beta_non_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=non_spec_query_start_loc[ : attn_metadata.num_decodes + 1 ], ssm_state_indices=non_spec_state_indices_tensor, use_qk_l2norm_in_kernel=True, ) ) else: core_attn_out_non_spec, last_recurrent_state = None, None # Merge core attention output if spec_sequence_masks is not None and core_attn_out_non_spec is not None: core_attn_out = torch.empty( (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) core_attn_out[:, spec_token_masks] = core_attn_out_spec core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec elif spec_sequence_masks is not None: core_attn_out = core_attn_out_spec else: core_attn_out = core_attn_out_non_spec z_shape_og = z.shape # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_actual_tokens], _ = self.out_proj(core_attn_out) class Qwen3NextAttention(nn.Module): def __init__( self, config: Qwen3NextConfig, model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = config.num_key_value_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = getattr( config, "dual_chunk_attention_config", None ) self.attn_output_gate = getattr(config, "attn_output_gate", True) self.qkv_proj = QKVParallelLinear( config.hidden_size, self.head_dim, self.total_num_heads * (1 + self.attn_output_gate), self.total_num_kv_heads, bias=getattr(config, "qkv_bias", False), quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=self.head_dim, max_position=config.max_position_embeddings, base=config.rope_theta, rope_scaling=config.rope_scaling, partial_rotary_factor=config.partial_rotary_factor, dual_chunk_attention_config=self.dual_chunk_attention_config, ) self.rotary_dim = self.head_dim if config.partial_rotary_factor < 1.0: self.rotary_dim = int(self.rotary_dim * config.partial_rotary_factor) self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", **( { "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": self.dual_chunk_attention_config, } if self.dual_chunk_attention_config else {} ), ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, output: torch.Tensor, hidden_states: torch.Tensor, ): qkv, _ = self.qkv_proj(hidden_states) q, k, v, gate = torch.ops.xspeedgate_ops.split_norm_rope_neox( qkv=qkv, q_weights=self.q_norm.weight, k_weights=self.k_norm.weight, positions=positions, cos_sin_cache=self.rotary_emb.cos_sin_cache, q_size=self.q_size, kv_size=self.kv_size, num_q_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, rotary_dim=self.rotary_dim, attn_output_gate=True, ) attn_output = self.attn(q, k, v) if self.attn_output_gate: gate = torch.sigmoid(gate) attn_output = attn_output * gate output[:], _ = self.o_proj(attn_output) class Qwen3NextDecoderLayer(nn.Module): def __init__( self, vllm_config: VllmConfig, layer_type: str, prefix: str = "", ) -> None: super().__init__() config = vllm_config.model_config.hf_config model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config speculative_config = vllm_config.speculative_config self.layer_type = layer_type self.layer_idx = extract_layer_index(prefix) if self.layer_type == "linear_attention": self.linear_attn = Qwen3NextGatedDeltaNet( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, speculative_config=speculative_config, prefix=f"{prefix}.linear_attn", ) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) else: raise ValueError(f"Invalid layer_type {self.layer_type}") mlp_only_layers = ( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) if (self.layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3NextSparseMoeBlock( vllm_config=vllm_config, prefix=f"{prefix}.mlp", ) else: self.mlp = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, ) self.input_layernorm = Qwen3NextRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.post_attention_layernorm = Qwen3NextRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.layer_scale = getattr(config, "layer_scale", False) if self.layer_scale: self.attn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.torch_dtype, ), ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.torch_dtype, ), ) def forward( self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], positions: torch.Tensor = None, **kwargs: object, ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) self_attention_output = torch.empty_like(hidden_states) if self.layer_type == "linear_attention": self.linear_attn( hidden_states=hidden_states, output=self_attention_output, ) elif self.layer_type == "full_attention": self.self_attn( hidden_states=hidden_states, output=self_attention_output, positions=positions, ) else: raise ValueError("Invalid layer_type") hidden_states = self_attention_output if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 ) else: hidden_states = hidden_states * ( self.attn_layer_scale.to(hidden_states.dtype) + 1 ) # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 ) else: assert len(hidden_states.shape) == len( self.ffn_layer_scale.shape ), f"shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}" # noqa: E501 hidden_states = hidden_states * ( self.ffn_layer_scale.to(hidden_states.dtype) + 1 ) return hidden_states, residual @support_torch_compile class Qwen3NextModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen3NextConfig = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config lora_config = vllm_config.lora_config eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config lora_vocab = ( (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 ) self.vocab_size = config.vocab_size + lora_vocab self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, ) def get_layer(prefix: str): return Qwen3NextDecoderLayer( vllm_config, layer_type=config.layer_types[extract_layer_index(prefix)], prefix=prefix, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) if get_pp_group().is_last_rank: self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, ) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, num_redundant_experts=self.num_redundant_experts, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if name.startswith("mtp."): continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # name = apply_attn_prefix(name, params_dict) if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. if ( name.endswith(".bias") or name.endswith("_bias") ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id, ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen3NextForCausalLM( nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid ): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config assert ( not cache_config.enable_prefix_caching ), "Qwen3Next currently does not support prefix caching" self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config self.model = Qwen3NextModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=( DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size ), prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor( self.unpadded_vocab_size, config.vocab_size ) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) # Set MoE hyperparameters self.expert_weights = [] self.moe_layers: list[FusedMoE] = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): continue assert isinstance(layer, Qwen3NextDecoderLayer) if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): example_layer = layer.mlp self.moe_layers.append(layer.mlp.experts) if example_layer is None: raise RuntimeError("No Qwen3Next layer found in the model.layers.") self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0 self.num_logical_experts = example_layer.n_logical_experts self.num_physical_experts = example_layer.n_physical_experts self.num_local_physical_experts = example_layer.n_local_physical_experts self.num_routed_experts = example_layer.n_routed_experts self.num_redundant_experts = example_layer.n_redundant_experts def set_eplb_state( self, expert_load_view: torch.Tensor, logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, ) -> None: for layer_idx, layer in enumerate(self.moe_layers): # Register the expert weights. self.expert_weights.append(layer.get_expert_weights()) layer.set_eplb_state( moe_layer_idx=layer_idx, expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, ) def update_physical_experts_metadata( self, num_physical_experts: int, num_local_physical_experts: int, ) -> None: assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): moe = layer.mlp moe.n_local_physical_experts = num_local_physical_experts moe.n_physical_experts = num_physical_experts moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype ) @classmethod def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig" ) -> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config tp_size = parallel_config.tensor_parallel_size num_spec = ( vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0 ) return MambaStateShapeCalculator.gated_delta_net_state_shape( tp_size, hf_config.linear_num_key_heads, hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, num_spec, ) def compute_logits( self, hidden_states: torch.Tensor, ) -> Optional[torch.Tensor]: return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=["mtp."], ) return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() def gdn_attention( hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] self._forward(hidden_states=hidden_states, output=output) def gdn_attention_fake( hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str, ) -> None: return direct_register_custom_op( op_name="gdn_attention", op_func=gdn_attention, mutates_args=["output"], fake_impl=gdn_attention_fake, ) # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) @triton.jit def fused_gdn_gating_kernel( g, A_log, a, dt_bias, seq_len, NUM_HEADS: tl.constexpr, beta: tl.constexpr, threshold: tl.constexpr, BLK_HEADS: tl.constexpr, ): i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off mask = head_off < NUM_HEADS blk_A_log = tl.load(A_log + head_off, mask=mask) blk_a = tl.load(a + off, mask=mask) blk_bias = tl.load(dt_bias + head_off, mask=mask) # If the model is loaded in fp16, without the .float() here, A might be -inf x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) softplus_x = tl.where( beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x ) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) def fused_gdn_gating( A_log: torch.Tensor, a: torch.Tensor, dt_bias: torch.Tensor, beta: float = 1.0, threshold: float = 20.0, ) -> torch.Tensor: batch, num_heads = a.shape seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty_like(a, dtype=torch.float32) fused_gdn_gating_kernel[grid]( g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 ) return g