diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py
index 430c1d58b..1e4dfb3df 100644
--- a/python/sglang/srt/models/exaone.py
+++ b/python/sglang/srt/models/exaone.py
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()