Fix the name of the probs variable computed without temperature scaling (#9984)
This commit is contained in:
@@ -80,9 +80,9 @@ class Sampler(nn.Module):
|
|||||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Post process original logits. if temperatures are all 1.0, no need to rescale
|
# If requested, cache probabilities from original logits before temperature scaling.
|
||||||
if return_logprob and RETURN_ORIGINAL_LOGPROB:
|
if return_logprob and RETURN_ORIGINAL_LOGPROB:
|
||||||
logprobs = torch.softmax(logits, dim=-1)
|
probs_without_temp_scaling = torch.softmax(logits, dim=-1)
|
||||||
|
|
||||||
# Post process logits
|
# Post process logits
|
||||||
logits.div_(sampling_info.temperatures)
|
logits.div_(sampling_info.temperatures)
|
||||||
@@ -123,9 +123,10 @@ class Sampler(nn.Module):
|
|||||||
if return_logprob:
|
if return_logprob:
|
||||||
# clamp to avoid -inf
|
# clamp to avoid -inf
|
||||||
if RETURN_ORIGINAL_LOGPROB:
|
if RETURN_ORIGINAL_LOGPROB:
|
||||||
logprobs = torch.log(logprobs).clamp(
|
logprobs = torch.log(probs_without_temp_scaling).clamp(
|
||||||
min=torch.finfo(logprobs.dtype).min
|
min=torch.finfo(probs_without_temp_scaling.dtype).min
|
||||||
)
|
)
|
||||||
|
del probs_without_temp_scaling
|
||||||
else:
|
else:
|
||||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user