From 4db463b1ad6edcd6b8cd500be377f65ff8e3b419 Mon Sep 17 00:00:00 2001 From: yhyang201 <47235274+yhyang201@users.noreply.github.com> Date: Sat, 19 Apr 2025 00:51:29 +0800 Subject: [PATCH] [Model] Adding Qwen3 and Qwen3MoE (#4693) --- .../layers/attention/flashinfer_backend.py | 7 +- python/sglang/srt/models/qwen2.py | 5 +- python/sglang/srt/models/qwen2_moe.py | 24 +- python/sglang/srt/models/qwen3.py | 335 ++++++++++++++ python/sglang/srt/models/qwen3_moe.py | 423 ++++++++++++++++++ 5 files changed, 780 insertions(+), 14 deletions(-) create mode 100644 python/sglang/srt/models/qwen3.py create mode 100644 python/sglang/srt/models/qwen3_moe.py diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index c29275103..37b9f4d6e 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -100,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend): self.num_wrappers = 1 self.dispatch_reason = None - # Qwen2 models require higher flashinfer workspace size - if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: + # Qwen2/Qwen3 models require higher flashinfer workspace size + if ( + "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures + or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures + ): global_config.flashinfer_workspace_size = 512 * 1024 * 1024 # Allocate buffers diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index b6646b5fb..4b518eccc 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -239,6 +239,7 @@ class Qwen2Model(nn.Module): config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer, ) -> None: super().__init__() self.config = config @@ -250,9 +251,11 @@ class Qwen2Model(nn.Module): quant_config=quant_config, prefix=add_prefix("embed_tokens", prefix), ) + # Use the provided decoder layer type or default to Qwen2DecoderLayer + decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer self.layers = make_layers( config.num_hidden_layers, - lambda idx, prefix: Qwen2DecoderLayer( + lambda idx, prefix: decoder_layer_type( layer_id=idx, config=config, quant_config=quant_config, diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 52dd697ef..f1851f526 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -47,7 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, make_layers expert_distribution_recorder = ExpertDistributionRecorder() @@ -333,6 +333,7 @@ class Qwen2MoeModel(nn.Module): config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -343,16 +344,17 @@ class Qwen2MoeModel(nn.Module): config.hidden_size, prefix=add_prefix("embed_tokens", prefix), ) - self.layers = nn.ModuleList( - [ - Qwen2MoeDecoderLayer( - config, - layer_id, - quant_config=quant_config, - prefix=add_prefix(f"layers.{layer_id}", prefix), - ) - for layer_id in range(config.num_hidden_layers) - ] + # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer + decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: decoder_layer_type( + layer_id=idx, + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("layers", prefix), ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py new file mode 100644 index 000000000..016f4b5de --- /dev/null +++ b/python/sglang/srt/models/qwen3.py @@ -0,0 +1,335 @@ +# Adapted from qwen2.py + +from functools import partial +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP +from sglang.srt.models.qwen2 import Qwen2Model +from sglang.srt.utils import add_prefix + +Qwen3Config = None + + +class Qwen3Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 1000000, + rope_scaling: Optional[Dict[str, Any]] = None, + head_dim: Optional[int] = None, + max_position_embeddings: int = 32768, + quant_config: Optional[QuantizationConfig] = None, + rms_norm_eps: float = None, + attention_bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= self.tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=add_prefix("attn", prefix), + ) + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3DecoderLayer(nn.Module): + def __init__( + self, + config: Qwen3Config, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + head_dim = getattr(config, "head_dim", None) + self.self_attn = Qwen3Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + head_dim=head_dim, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + prefix=add_prefix("self_attn", prefix), + ) + self.mlp = Qwen3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen3Model(Qwen2Model): + def __init__( + self, + config: Qwen3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + quant_config=quant_config, + prefix=prefix, + decoder_layer_type=Qwen3DecoderLayer, + ) + + +class Qwen3ForCausalLM(nn.Module): + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: Qwen3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Qwen3Model( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + else: + return self.pooler(hidden_states, forward_batch) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + + +EntryClass = Qwen3ForCausalLM diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py new file mode 100644 index 000000000..c8786e4e5 --- /dev/null +++ b/python/sglang/srt/models/qwen3_moe.py @@ -0,0 +1,423 @@ +# Adapted from qwen2_moe.py + +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" + +from functools import partial +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP +from sglang.srt.models.qwen2_moe import Qwen2MoeModel +from sglang.srt.utils import add_prefix + +Qwen3MoeConfig = None + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__( + self, + config: Qwen3MoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=add_prefix("experts", prefix), + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=add_prefix("gate", prefix), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Qwen3MoeAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + attention_bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= self.tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.tp_rank = get_tensor_model_parallel_rank() + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=add_prefix("attn", prefix), + ) + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3MoeDecoderLayer(nn.Module): + def __init__( + self, + config: Qwen3MoeConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + rms_norm_eps = config.rms_norm_eps + attention_bias = config.attention_bias + self.self_attn = Qwen3MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + attention_bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + + # Note: Qwen/Qwen2-57B-A14B-Instruct does not have + # `mlp_only_layers` in the config. + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + if (layer_id not in mlp_only_layers) and ( + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + else: + self.mlp = Qwen3MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen3MoeModel(Qwen2MoeModel): + def __init__( + self, + config: Qwen3MoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + quant_config=quant_config, + prefix=prefix, + decoder_layer_type=Qwen3MoeDecoderLayer, + ) + + +class Qwen3MoeForCausalLM(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: Qwen3MoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = Qwen3MoeForCausalLM