fix: EXAONE when using tie_word_embeddings (#5759)
This commit is contained in:
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
Reference in New Issue
Block a user