diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 6c902655d..d0608129a 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -293,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -321,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -342,7 +350,11 @@ class Qwen2Model(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 001b3a336..92637d73b 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -440,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -459,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         if forward_batch.can_run_tbo:
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
@@ -471,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(
+                        hidden_states + residual
+                        if residual is not None
+                        else hidden_states
+                    )
                 with get_global_expert_distribution_recorder().with_current_layer(i):
                     layer = self.layers[i]
                     hidden_states, residual = layer(
@@ -489,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
 
 class Qwen2MoeForCausalLM(nn.Module):
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index c5ead114c..9c3659839 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -2,7 +2,7 @@
 
 import logging
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -325,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -346,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -447,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 9d5ce6103..0ca0a9509 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -717,6 +717,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -735,9 +736,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -750,6 +755,24 @@ class Qwen3MoeForCausalLM(nn.Module):
     def end_layer(self):
         return self.model.end_layer
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
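
Note (not part of the patch): a minimal, self-contained sketch of how the capture hook added above is meant to behave. The TargetModelStub class below is hypothetical and does not use SGLang APIs; it only mirrors the logic from the diff: set_eagle3_layers_to_capture() picks a low/middle/high layer by default (or shifts caller-provided ids by one so the input of each requested layer is captured), after which forward() returns (hidden_states, aux_hidden_states) instead of a single tensor. How the draft model consumes the captured states (concatenation along the feature dimension) is an assumption about typical EAGLE3 usage, not something this diff implements.

import torch

class TargetModelStub:
    """Hypothetical stand-in for Qwen3ForCausalLM with the EAGLE3 hook."""

    def __init__(self, num_hidden_layers: int = 36, hidden_size: int = 64):
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.layers_to_capture = []
        self.capture_aux_hidden_states = False

    def set_eagle3_layers_to_capture(self, layer_ids=None):
        # Mirrors the diff: default to layers [2, N // 2, N - 3],
        # otherwise capture the input of each requested layer (id + 1).
        self.capture_aux_hidden_states = True
        if layer_ids is None:
            n = self.num_hidden_layers
            self.layers_to_capture = [2, n // 2, n - 3]
        else:
            self.layers_to_capture = [val + 1 for val in layer_ids]

    def forward(self, hidden_states: torch.Tensor):
        aux_hidden_states = []
        for i in range(self.num_hidden_layers):
            if i in self.layers_to_capture:
                # Capture the pre-layer hidden state, as in the diff.
                aux_hidden_states.append(hidden_states)
            # Dummy decoder layer so the sketch runs end to end.
            hidden_states = hidden_states + 0.01 * torch.tanh(hidden_states)
        if len(aux_hidden_states) == 0:
            return hidden_states
        return hidden_states, aux_hidden_states

model = TargetModelStub()
model.set_eagle3_layers_to_capture()  # enable EAGLE3 capture on the target model
hidden, aux = model.forward(torch.randn(4, model.hidden_size))
draft_input = torch.cat(aux, dim=-1)  # assumed downstream use by the draft head
print(hidden.shape, draft_input.shape)  # torch.Size([4, 64]) torch.Size([4, 192])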