Refactor logprob computation to return the real logprob used in sampling (#2664)

This commit is contained in:
Lianmin Zheng
2024-12-30 04:51:38 -08:00
committed by GitHub
parent b02da24a5b
commit 9c6ba2484f
9 changed files with 305 additions and 312 deletions

View File

@@ -232,3 +232,26 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
def apply_logits_bias(self, logits: torch.Tensor):
    """Apply all configured sampling biases/penalties to ``logits``.

    Steps, in order: additive logit bias, additive linear penalties
    (min-token / presence / frequency), multiplicative repetition
    scaling, and finally the regex vocab mask.  The additive steps
    mutate ``logits`` in place; the repetition step allocates a new
    tensor, so callers must use the returned value.
    """
    bias = self.logit_bias
    if bias is not None:
        logits.add_(bias)

    linear = self.linear_penalties
    if linear is not None:
        logits.add_(linear)

    scaling = self.scaling_penalties
    if scaling is not None:
        # Divide positive logits and multiply negative ones so the
        # penalty always pushes logits toward zero (less likely).
        logits = torch.where(
            logits > 0,
            logits / scaling,
            logits * scaling,
        )

    mask = self.vocab_mask
    if mask is not None:
        self.apply_mask(logits=logits, vocab_mask=mask)

    return logits