Fix the name of the probs variable computed without temperature scaling (#9984)
This commit is contained in:
@@ -80,9 +80,9 @@ class Sampler(nn.Module):
|
||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
|
||||
else:
|
||||
# Post process original logits. If temperatures are all 1.0, no need to rescale.
|
||||
# If requested, cache probabilities from original logits before temperature scaling.
|
||||
if return_logprob and RETURN_ORIGINAL_LOGPROB:
|
||||
logprobs = torch.softmax(logits, dim=-1)
|
||||
probs_without_temp_scaling = torch.softmax(logits, dim=-1)
|
||||
|
||||
# Post process logits
|
||||
logits.div_(sampling_info.temperatures)
|
||||
@@ -123,9 +123,10 @@ class Sampler(nn.Module):
|
||||
if return_logprob:
|
||||
# clamp to avoid -inf
|
||||
if RETURN_ORIGINAL_LOGPROB:
|
||||
logprobs = torch.log(logprobs).clamp(
|
||||
min=torch.finfo(logprobs.dtype).min
|
||||
logprobs = torch.log(probs_without_temp_scaling).clamp(
|
||||
min=torch.finfo(probs_without_temp_scaling.dtype).min
|
||||
)
|
||||
del probs_without_temp_scaling
|
||||
else:
|
||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user