From 2d54d4bb648dea1bcb13af2ddab30fcf4e1fa37f Mon Sep 17 00:00:00 2001 From: Binyao Jiang Date: Wed, 9 Jul 2025 23:51:33 -0700 Subject: [PATCH] Feat: Support Phi-3.5-MoE in SGLang (#7907) --- docs/supported_models/generative_models.md | 2 +- python/sglang/srt/models/phimoe.py | 553 +++++++++++++++++++++ test/srt/models/test_generation_models.py | 6 + 3 files changed, 560 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/models/phimoe.py diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 3ffd6a883..5b22f70c9 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -30,7 +30,7 @@ in the GitHub search bar. | **Llama** (2, 3.x, 4 series) | `meta-llama/Llama-4-Scout-17B-16E-Instruct` | Meta’s open LLM series, spanning 7B to 400B parameters (Llama 2, 3, and new Llama 4) with well-recognized performance. [SGLang provides Llama-4 model-specific optimizations](https://docs.sglang.ai/references/llama4) | | **Mistral** (Mixtral, NeMo, Small3) | `mistralai/Mistral-7B-Instruct-v0.2` | Open 7B LLM by Mistral AI with strong performance; extended into MoE (“Mixtral”) and NeMo Megatron variants for larger scale. | | **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. | -| **Phi** (Phi-3, Phi-4 series) | `microsoft/Phi-4-multimodal-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-mini is a high-accuracy text model and Phi-4-multimodal (5.6B) processes text, images, and speech in one compact model. | +| **Phi** (Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-mini is a high-accuracy text model, Phi-3.5-MoE is a mixture-of-experts model, and Phi-4-multimodal (5.6B) processes text, images, and speech. | | **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. | | **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. | | **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. | diff --git a/python/sglang/srt/models/phimoe.py b/python/sglang/srt/models/phimoe.py new file mode 100644 index 000000000..22ee023c8 --- /dev/null +++ b/python/sglang/srt/models/phimoe.py @@ -0,0 +1,553 @@ +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer +from sglang.srt.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers + + +class PhiMoEConfig(PretrainedConfig): + + model_type = "phimoe" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=None, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + if head_dim is None: + head_dim = hidden_size // num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def sparsemixer(scores, jitter_eps=0.01): + ################ Select first expert (topk=2) ################ + + # compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > ( + 2 * jitter_eps + ) + + # apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts = max_ind + + # compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + multiplier = multiplier_o + + # masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + + ################ Select second expert (topk=2) ################ + # compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > ( + 2 * jitter_eps + ) + + # apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts_top2 = max_ind + # compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2 = masked_gates_top2.gather(dim=-1, index=selected_experts_top2) + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +def phimoe_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert topk == 2, "Only top-2 routing is supported" + assert renormalize is False, "Renormalization is not supported" + + topk_weights, topk_ids = sparsemixer(gating_output) + return topk_weights, topk_ids + + +class PhiMoE(nn.Module): + """A tensor-parallel MoE implementation for PhiMoE that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + quant_config=None, + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + custom_routing_function=phimoe_routing_function, + prefix=add_prefix("experts", prefix), + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None + ) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class PhiMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + layer_id: int = 0, + attention_bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[dict] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + + self.total_num_heads = num_heads + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= attn_tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % attn_tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) + if head_dim is None: + head_dim = hidden_size // num_heads + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=attention_bias, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=add_prefix("o_proj", prefix), + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + rope_scaling=self.rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class PhiMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = PhiMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr( + config, "head_dim", self.hidden_size // config.num_attention_heads + ), + rope_theta=rope_theta, + layer_id=layer_id, + attention_bias=config.attention_bias, + quant_config=quant_config, + rope_scaling=config.rope_scaling, + prefix=add_prefix("self_attn", prefix), + ) + self.block_sparse_moe = PhiMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("block_sparse_moe", prefix), + ) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + ) -> Tuple[torch.Tensor, torch.Tensor]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe( + hidden_states, forward_batch=forward_batch + ) + + hidden_states = hidden_states + residual + return hidden_states, residual + + +class PhiMoEModel(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), + ) + + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: PhiMoEDecoderLayer( + config, int(prefix.split(".")[-1]), quant_config, prefix=prefix + ), + prefix=add_prefix("layers", prefix), + ) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor]: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + + for layer in self.layers: + hidden_states, residual = layer( + positions, hidden_states, residual, forward_batch=forward_batch + ) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class PhiMoEForCausalLM(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + + super().__init__() + self.config = config + self.quant_config = quant_config + + self.model = PhiMoEModel( + config=config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + bias=True, + prefix=add_prefix("lm_head", prefix), + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) + + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + else: + return self.pooler(hidden_states, forward_batch) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = PhiMoEForCausalLM diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 54ba7b8f2..f30347d46 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -68,6 +68,12 @@ ALL_MODELS = [ ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), + ModelCase( + "microsoft/Phi-3.5-MoE-instruct", + tp_size=2, + trust_remote_code=True, + skip_long_prompt=True, + ), ] TORCH_DTYPES = [torch.float16]