From 99757cc3e68418f48713e69e5543a5783abcd84d Mon Sep 17 00:00:00 2001
From: narutolhy <582909902@qq.com>
Date: Fri, 12 Sep 2025 21:19:57 -0700
Subject: [PATCH] fix probs name which without temp scaling name (#9984)

---
 python/sglang/srt/layers/sampler.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 56a831f2d..565ce106e 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -80,9 +80,9 @@ class Sampler(nn.Module):
             logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
-            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.softmax(logits, dim=-1)
+                probs_without_temp_scaling = torch.softmax(logits, dim=-1)
 
             # Post process logits
             logits.div_(sampling_info.temperatures)
 
@@ -123,9 +123,10 @@ class Sampler(nn.Module):
         if return_logprob:
             # clamp to avoid -inf
             if RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.log(logprobs).clamp(
-                    min=torch.finfo(logprobs.dtype).min
+                logprobs = torch.log(probs_without_temp_scaling).clamp(
+                    min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
+                del probs_without_temp_scaling
             else:
                 logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 