Files
sglang/python/sglang/srt/models/qwen3_moe.py
2025-05-16 14:44:10 -07:00

622 lines
23 KiB
Python

# Adapted from qwen2_moe.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.utils import add_prefix
Qwen3MoeConfig = None
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}."
)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=add_prefix("gate", prefix),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class Qwen3MoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
attention_bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.total_num_heads = num_heads
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.tp_rank = get_tensor_model_parallel_rank()
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
return q, k
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
class _FFNInputMode(Enum):
# The MLP sublayer requires 1/tp_size tokens as input
SCATTERED = auto()
# The MLP sublayer requires all tokens as input
FULL = auto()
@dataclass
class _DecoderLayerInfo:
is_sparse: bool
ffn_input_mode: _FFNInputMode
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3MoeConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.layer_id = layer_id
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
self.info = self._compute_info(config, layer_id=layer_id)
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
self.input_is_scattered = (
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
if self.info.is_sparse:
self.mlp = Qwen3MoeSparseMoeBlock(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
else:
self.mlp = Qwen3MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
@staticmethod
def _enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1
@staticmethod
def _compute_info(config: PretrainedConfig, layer_id: int):
# WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
)
is_sparse = (layer_id not in mlp_only_layers) and (
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
)
ffn_input_mode = (
_FFNInputMode.SCATTERED
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
else _FFNInputMode.FULL
)
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
)
else:
raise NotImplementedError
def forward_ffn_with_full_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
elif hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
# Fully Connected
hidden_states = self.mlp(hidden_states)
# TODO: use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
def forward_ffn_with_scattered_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
# Self Attention
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
if not (
self._enable_moe_dense_fully_dp()
and (not self.info.is_sparse)
and hidden_states.shape[0] == 0
):
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
return hidden_states, residual
class Qwen3MoeModel(Qwen2MoeModel):
def __init__(
self,
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
decoder_layer_type=Qwen3MoeDecoderLayer,
)
class Qwen3MoeForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen3MoeModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = Qwen3MoeForCausalLM