fix: Fix returned prefill logits and add output str test (#1046)

This commit is contained in:
Ying Sheng
2024-08-11 23:13:45 -07:00
committed by GitHub
parent fb1f28cbbb
commit 32f6144323
3 changed files with 33 additions and 14 deletions

View File

@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size].float()
if hasattr(self.config, "final_logit_softcapping"):
all_logits /= self.config.final_logit_softcapping
all_logits = torch.tanh(all_logits)
all_logits *= self.config.final_logit_softcapping
all_logprobs = all_logits
del all_logits, hidden_states
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)