From 681e7af32b7508e2de3846011e61b64cf6f77594 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Sat, 24 May 2025 21:35:55 -0700 Subject: [PATCH] [OAI] Support non-normalized logprobs in OpenAI server (#5961) --- python/sglang/srt/layers/sampler.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 8ed50b1c9..79d13908f 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -86,11 +86,9 @@ class Sampler(nn.Module): # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, # https://github.com/flashinfer-ai/flashinfer/issues/708 # so we use the torch implementation. - - # clamp to avoid -inf - logprobs = torch.log( - top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ).clamp(min=torch.finfo(probs.dtype).min) + # NOTE: OpenAI's logprobs is independent of top-p, we use the + # same rule. + logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) max_top_k_round, batch_size = 32, probs.shape[0] if sampling_info.need_min_p_sampling: @@ -121,10 +119,7 @@ class Sampler(nn.Module): ) if return_logprob: - # clamp to avoid -inf - logprobs = torch.log( - top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ).clamp(min=torch.finfo(probs.dtype).min) + logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"