Add Gemma2 (#592)

This commit is contained in:
Ying Sheng
2024-07-05 09:48:54 -07:00
committed by GitHub
parent d737da5f17
commit 5a57b8addd
7 changed files with 467 additions and 30 deletions

View File

@@ -108,6 +108,11 @@ class LogitsProcessor(nn.Module):
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
if hasattr(self.config, "final_logit_softcapping"):
last_logits /= self.config.final_logit_softcapping
last_logits = torch.tanh(last_logits)
last_logits *= self.config.final_logit_softcapping
# Return only last_logits if logprob is not requested
if not input_metadata.return_logprob:
return LogitProcessorOutput(