Move sampling logits to float32 (#773)
@@ -136,7 +136,7 @@ class LogitsProcessor(nn.Module):
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
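With this hunk, the last-position logits are cast to float32 as soon as they are sliced to the vocabulary size, so the final-logit softcapping used by models such as Gemma 2 (and everything downstream of it) runs in full precision rather than the model's half-precision compute dtype. A minimal sketch of why the earlier cast matters, with illustrative shapes and the usual tanh-based softcap (the hunk itself only shows the initial division):

import torch

# Illustrative only: softcapping logits in bfloat16 vs. float32.
torch.manual_seed(0)
logits = (torch.randn(4, 32000) * 50.0).bfloat16()
softcap = 30.0

# Old behavior: the cap is applied in the model's compute dtype.
capped_half_precision = torch.tanh(logits / softcap) * softcap

# New behavior: cast to float32 right after slicing, then apply the cap.
capped_float32 = torch.tanh(logits.float() / softcap) * softcap

# Rounding in the low-precision path perturbs the capped logits, which
# shifts sampling probabilities and reported logprobs.
print((capped_float32 - capped_half_precision.float()).abs().max())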
@@ -161,9 +161,9 @@ class LogitsProcessor(nn.Module):
         all_logits = torch.matmul(hidden_states, weight.T)
         if self.tp_size > 1:
             all_logits = tensor_model_parallel_all_gather(all_logits)
-        all_logits = all_logits[:, : self.config.vocab_size]
+        all_logits = all_logits[:, : self.config.vocab_size].float()

-        all_logprobs = all_logits.float()
+        all_logprobs = all_logits
         del all_logits
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
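In the all-logits path the float32 cast likewise moves up to the slicing step, so all_logits is already float32 by the time logprobs are computed; the old all_logits.float() copy is dropped, and log_softmax is written back into the same float32 buffer, avoiding a second full-vocabulary allocation. A minimal sketch of that buffer reuse, with made-up shapes and names (the padded vocabulary size is illustrative):

import torch
import torch.nn.functional as F

# Illustrative only: slice, cast once to float32, then reuse that buffer
# for an in-place log_softmax instead of allocating a second copy.
vocab_size = 32000
raw_logits = torch.randn(8, vocab_size + 128).bfloat16()  # padded vocab, made up

all_logits = raw_logits[:, :vocab_size].float()  # single cast after slicing

all_logprobs = all_logits  # alias the float32 buffer, no new allocation
del all_logits
all_logprobs[:] = F.log_softmax(all_logprobs, dim=-1)

print(all_logprobs.dtype, all_logprobs[0].exp().sum())  # torch.float32, ~1.0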