diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py
index f8d7b608c..5e632d5e4 100644
--- a/python/sglang/srt/models/llama_eagle3.py
+++ b/python/sglang/srt/models/llama_eagle3.py
@@ -185,9 +185,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         )
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        self.load_lm_head_from_target = False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
+            if config.draft_vocab_size is None:
+                self.load_lm_head_from_target = True
+                config.draft_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 config.draft_vocab_size,
                 config.hidden_size,
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 4829fc83e..5a9454cd2 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -137,8 +137,15 @@ class EAGLEWorker(TpModelWorker):
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()

         if self.speculative_algorithm.is_eagle3():
-            # EAGLE3 models don't share lm_head
-            self.draft_model_runner.model.set_embed(embed)
+            # In most cases EAGLE3 models don't share lm_head with the target,
+            # but some (e.g. nvidia/gpt-oss-120b-Eagle3) do.
+            if (
+                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
+                and self.draft_model_runner.model.load_lm_head_from_target
+            ):
+                self.draft_model_runner.model.set_embed_and_head(embed, head)
+            else:
+                self.draft_model_runner.model.set_embed(embed)

             # grab hot token ids
             if self.draft_model_runner.model.hot_token_id is not None:
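
For reference, below is a minimal, self-contained sketch of the control flow this patch introduces. DraftConfig, DraftModel, and share_target_weights are illustrative stand-ins invented for this sketch, not sglang APIs; only the gating logic mirrors the diff above.

# Standalone sketch of the lm_head-sharing logic from the patch.
# All class/function names here are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DraftConfig:
    vocab_size: int
    hidden_size: int
    tie_word_embeddings: bool
    draft_vocab_size: Optional[int] = None


class DraftModel:
    def __init__(self, config: DraftConfig):
        self.config = config
        # Mirrors the llama_eagle3.py change: default to False, flip to
        # True only when the checkpoint ships no draft vocab of its own.
        self.load_lm_head_from_target = False
        if config.tie_word_embeddings:
            self.lm_head = "embed_tokens"  # shared with input embeddings
        else:
            if config.draft_vocab_size is None:
                # No draft-specific vocab: plan to reuse the target's head.
                self.load_lm_head_from_target = True
                config.draft_vocab_size = config.vocab_size
            self.lm_head = f"ParallelLMHead({config.draft_vocab_size})"

    def set_embed(self, embed):
        self.embed = embed

    def set_embed_and_head(self, embed, head):
        self.embed, self.lm_head = embed, head


def share_target_weights(draft: DraftModel, embed, head):
    # Mirrors the eagle_worker.py change: the hasattr/getattr guard keeps
    # this safe for draft models that predate the new flag.
    if getattr(draft, "load_lm_head_from_target", False):
        draft.set_embed_and_head(embed, head)
    else:
        draft.set_embed(embed)


if __name__ == "__main__":
    # An EAGLE3 config without its own draft vocab, as in
    # nvidia/gpt-oss-120b-Eagle3: the target's head gets reused.
    cfg = DraftConfig(vocab_size=32000, hidden_size=4096,
                      tie_word_embeddings=False)
    draft = DraftModel(cfg)
    share_target_weights(draft, embed="target_embed", head="target_head")
    assert draft.load_lm_head_from_target and draft.lm_head == "target_head"

The key design point is backward compatibility: draft checkpoints that define their own draft_vocab_size take the existing set_embed path untouched, and the hasattr guard means draft model classes that never define the flag behave exactly as before.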