fix: EXAONE when using tie_word_embeddings (#5759)

Author: Kyungmin Lee
Date: 2025-05-22 03:30:04 +09:00
Committed by: GitHub
Parent: cfe48c5902
Commit: ada268fd05


@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
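
For context, tie_word_embeddings means the output projection reuses the input embedding table rather than allocating a separate lm_head weight, which is why the fix points self.lm_head at self.transformer.wte. Below is a minimal plain-PyTorch sketch of the same pattern; the TinyLM class and its parameters are illustrative only, not sglang's ExaoneForCausalLM or ParallelLMHead.

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Illustrative stand-in for a causal LM head, not sglang code.
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)  # input token embeddings
        if tie_word_embeddings:
            # Reuse the embedding module as the output head: the model then
            # carries a single vocab_size x hidden_size matrix shared by
            # input lookup and output projection.
            self.lm_head = self.wte
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Both nn.Embedding and bias-free nn.Linear expose a .weight of shape
        # (vocab_size, hidden_size), so one matmul serves both branches.
        return hidden_states @ self.lm_head.weight.t()

A checkpoint saved with tied embeddings typically ships no separate lm_head weight, so unconditionally constructing a fresh ParallelLMHead would leave the head without loadable parameters; branching on config.tie_word_embeddings, as the diff does, avoids that.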