diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 2b1ea57fd..531f5b6e9 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -16,7 +16,7 @@
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
-
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)
@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
 
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2ForCausalLM
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index a3427e068..56ac79a7f 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -17,7 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -553,9 +555,12 @@ class Qwen2MoeForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module):
             num_groups=None,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2MoeForCausalLM
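
Note for reviewers: a minimal, self-contained sketch of how the new pieces are expected to fit together. The capture loop inside Qwen2Model is not part of this diff, so ToyModel below is a hypothetical stand-in. It assumes the model appends the hidden state seen at each index in layers_to_capture and returns a (hidden_states, aux_hidden_states) tuple, which is what the tuple unpacking added to forward() consumes, and why set_eagle3_layers_to_capture shifts user-supplied layer_ids by +1 (the input of layer i + 1 is the output of layer i).

from typing import List

import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for Qwen2Model: a layer stack with optional capture."""

    def __init__(self, num_layers: int, hidden_size: int):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )
        # Filled in via set_eagle3_layers_to_capture() on the wrapper class.
        self.layers_to_capture: List[int] = []

    def forward(self, x: torch.Tensor):
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            # Capturing the input of layer i records the output of layer
            # i - 1, hence the "+ 1" shift applied to user-supplied layer_ids.
            if i in self.layers_to_capture:
                aux_hidden_states.append(x)
            x = layer(x)
        if self.layers_to_capture:
            # Same tuple shape that the patched forward() unpacks via
            # "hidden_states, aux_hidden_states = hidden_states".
            return x, aux_hidden_states
        return x


num_layers, hidden_size = 28, 16
model = ToyModel(num_layers, hidden_size)
# Default EAGLE3 choice from the diff: one low, one middle, one late layer.
model.layers_to_capture = [2, num_layers // 2, num_layers - 3]  # [2, 14, 25]

hidden_states, aux_hidden_states = model(torch.randn(1, hidden_size))
print(hidden_states.shape, len(aux_hidden_states))  # torch.Size([1, 16]) 3

The setter being a no-op on non-last pipeline ranks matches the guard in the diff: only the rank that runs the logits processor ever produces or consumes aux_hidden_states.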