add support for nvidia/gpt-oss-120b-Eagle3 (#9739)
This commit is contained in:
@@ -185,9 +185,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
|
|||||||
)
|
)
|
||||||
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
||||||
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
||||||
|
self.load_lm_head_from_target = False
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.lm_head = self.model.embed_tokens
|
self.lm_head = self.model.embed_tokens
|
||||||
else:
|
else:
|
||||||
|
if config.draft_vocab_size is None:
|
||||||
|
self.load_lm_head_from_target = True
|
||||||
|
config.draft_vocab_size = config.vocab_size
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.draft_vocab_size,
|
config.draft_vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
|||||||
@@ -137,8 +137,15 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
|
|
||||||
if self.speculative_algorithm.is_eagle3():
|
if self.speculative_algorithm.is_eagle3():
|
||||||
# EAGLE3 models don't share lm_head
|
# In most cases EAGLE3 models don't share lm_head
|
||||||
self.draft_model_runner.model.set_embed(embed)
|
# but some models (e.g. nvidia/gpt-oss-120b-Eagle3) do share it
|
||||||
|
if (
|
||||||
|
hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
|
||||||
|
and self.draft_model_runner.model.load_lm_head_from_target
|
||||||
|
):
|
||||||
|
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||||
|
else:
|
||||||
|
self.draft_model_runner.model.set_embed(embed)
|
||||||
|
|
||||||
# grab hot token ids
|
# grab hot token ids
|
||||||
if self.draft_model_runner.model.hot_token_id is not None:
|
if self.draft_model_runner.model.hot_token_id is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user