diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 20fa13493..5d2e2bc82 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -610,6 +610,12 @@ class LlamaForCausalLM(nn.Module):
         return self.model.embed_tokens.weight
 
     def set_embed(self, embed):
+        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
+        if (
+            hasattr(self.config, "target_hidden_size")
+            and self.config.target_hidden_size != self.config.hidden_size
+        ):
+            return
         del self.model.embed_tokens.weight
         self.model.embed_tokens.weight = embed
         torch.cuda.empty_cache()
diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py
index 56342fe24..137a6da56 100644
--- a/python/sglang/srt/models/llama_eagle3.py
+++ b/python/sglang/srt/models/llama_eagle3.py
@@ -105,7 +105,10 @@ class LlamaModel(nn.Module):
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
 
-        self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+        if hasattr(config, "target_hidden_size"):
+            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+        else:
+            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
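
For context on the second hunk: EAGLE3 fuses hidden states captured from multiple layers of the target model before feeding them to the draft layer, which is why `fc` takes an input of hidden size times 3. When the draft model is trained with a narrower hidden size than the target, that projection must map from the target's width, hence `config.target_hidden_size * 3`. The snippet below is a minimal sketch of this shape relationship, not SGLang code; the sizes 4096/2048, the batch/sequence dimensions, and the random tensors are illustrative assumptions only.

```python
import torch

# Minimal sketch (not SGLang code) of why fc takes hidden size * 3:
# EAGLE3 concatenates hidden states captured from three layers of the
# target model, so when the draft's hidden size differs from the
# target's, fc must project 3 * target_hidden_size down to the draft
# hidden size. All sizes below are illustrative assumptions.
target_hidden_size, draft_hidden_size = 4096, 2048

fc = torch.nn.Linear(target_hidden_size * 3, draft_hidden_size)

batch, seq = 2, 8
# Stand-ins for per-token hidden states captured from the target model.
captured = [torch.randn(batch, seq, target_hidden_size) for _ in range(3)]

# Concatenate along the feature dim, then project to the draft width.
fused = fc(torch.cat(captured, dim=-1))
assert fused.shape == (batch, seq, draft_hidden_size)
```

The `hasattr(config, "target_hidden_size")` guard in both hunks keeps backward compatibility: draft configs that predate this field (i.e., drafts whose hidden size matches the target's) fall through to the original behavior, including sharing the embedding weight in `set_embed`.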