Refactor logprob computation to return the real logprob used in sampling (#2664)
This commit is contained in:
@@ -232,3 +232,26 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor):
    """Apply this batch's sampling adjustments to ``logits`` and return the result.

    Adjustments are applied in order: additive logit bias, additive
    linear penalties (min-token / presence / frequency), multiplicative
    repetition scaling, and finally the regex/grammar vocab mask.
    The additive steps mutate ``logits`` in place; the scaling step
    rebinds ``logits`` to a new tensor, so callers must use the
    returned tensor.
    """
    # Additive adjustments (in place): user-supplied logit bias first,
    # then the combined min-token / presence / frequency penalties.
    for additive_term in (self.logit_bias, self.linear_penalties):
        if additive_term is not None:
            logits.add_(additive_term)

    # Repetition penalty: shrink positive logits, amplify negative ones
    # (standard divide-above-zero / multiply-below-zero formulation).
    if self.scaling_penalties is not None:
        logits = torch.where(
            logits > 0,
            logits / self.scaling_penalties,
            logits * self.scaling_penalties,
        )

    # Constrained decoding: mask tokens disallowed by the regex grammar.
    # NOTE(review): apply_mask is defined elsewhere in this class;
    # presumably it mutates logits in place — confirm against its body.
    if self.vocab_mask is not None:
        self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)

    return logits
|
||||
|
||||
Reference in New Issue
Block a user