diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 8ed50b1c9..79d13908f 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -86,11 +86,9 @@ class Sampler(nn.Module): # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, # https://github.com/flashinfer-ai/flashinfer/issues/708 # so we use the torch implementation. - - # clamp to avoid -inf - logprobs = torch.log( - top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ).clamp(min=torch.finfo(probs.dtype).min) + # NOTE: OpenAI's logprobs are independent of top-p; we use the + # same rule. + logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) max_top_k_round, batch_size = 32, probs.shape[0] if sampling_info.need_min_p_sampling: @@ -121,10 +119,7 @@ class Sampler(nn.Module): ) if return_logprob: - # clamp to avoid -inf - logprobs = torch.log( - top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ).clamp(min=torch.finfo(probs.dtype).min) + logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"