diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 56a831f2d..565ce106e 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -80,9 +80,9 @@ class Sampler(nn.Module): logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: - # Post process original logits. if temperatures are all 1.0, no need to rescale + # If requested, cache probabilities from original logits before temperature scaling. if return_logprob and RETURN_ORIGINAL_LOGPROB: - logprobs = torch.softmax(logits, dim=-1) + probs_without_temp_scaling = torch.softmax(logits, dim=-1) # Post process logits logits.div_(sampling_info.temperatures) @@ -123,9 +123,10 @@ class Sampler(nn.Module): if return_logprob: # clamp to avoid -inf if RETURN_ORIGINAL_LOGPROB: - logprobs = torch.log(logprobs).clamp( - min=torch.finfo(logprobs.dtype).min + logprobs = torch.log(probs_without_temp_scaling).clamp( + min=torch.finfo(probs_without_temp_scaling.dtype).min ) + del probs_without_temp_scaling else: logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)