fix name of probs computed without temperature scaling (#9984)

This commit is contained in:
narutolhy
2025-09-12 21:19:57 -07:00
committed by GitHub
parent cdddab056c
commit 99757cc3e6

View File

@@ -80,9 +80,9 @@ class Sampler(nn.Module):
             logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
-            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.softmax(logits, dim=-1)
+                probs_without_temp_scaling = torch.softmax(logits, dim=-1)
             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -123,9 +123,10 @@ class Sampler(nn.Module):
         if return_logprob:
             # clamp to avoid -inf
             if RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.log(logprobs).clamp(
-                    min=torch.finfo(logprobs.dtype).min
+                logprobs = torch.log(probs_without_temp_scaling).clamp(
+                    min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
+                del probs_without_temp_scaling
             else:
                 logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)