Add Gemma2 (#592)
@@ -108,6 +108,11 @@ class LogitsProcessor(nn.Module):
         last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size]
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
             return LogitProcessorOutput(
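For context, the added lines apply Gemma2's final logit softcapping: the logits are divided by the cap, passed through tanh, and scaled back up, which bounds every logit to (-cap, cap). Below is a minimal standalone sketch of that transform; the helper name softcap_logits and the cap value 30.0 (Gemma2's commonly reported default for final_logit_softcapping) are illustrative assumptions, not part of this change.

import torch

# Illustrative sketch (not the patched code): bound logits to (-cap, cap)
# using the same divide -> tanh -> multiply sequence as the diff above.
def softcap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(logits / cap)

# Example: extreme logits are squashed toward +/-30, small ones barely change.
logits = torch.tensor([[-100.0, 0.5, 250.0]])
print(softcap_logits(logits, 30.0))  # roughly [-29.9, 0.5, 30.0]

Note that the patch guards this with hasattr(self.config, "final_logit_softcapping"), so models whose configs do not define that field pass through this code path unchanged.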