fix: Fix returned prefill logits and add output str test (#1046)
This commit is contained in:
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
|
||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||
|
||||
if hasattr(self.config, "final_logit_softcapping"):
|
||||
all_logits /= self.config.final_logit_softcapping
|
||||
all_logits = torch.tanh(all_logits)
|
||||
all_logits *= self.config.final_logit_softcapping
|
||||
|
||||
all_logprobs = all_logits
|
||||
del all_logits, hidden_states
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user