From 734daedd8fd9155fa4854b88d3c36cb90831e441 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 31 Jan 2025 01:04:04 -0800 Subject: [PATCH] [fix] Clamp logprob with dtype min to prevent `-inf` (#3224) --- python/sglang/srt/layers/sampler.py | 7 +++++-- .../penaltylib/test_srt_endpoint_with_penalizers.py | 7 +++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b24bfc8da..73ef13c35 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -72,9 +72,11 @@ class Sampler(nn.Module): # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, # https://github.com/flashinfer-ai/flashinfer/issues/708 # so we use the torch implementation. + + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( @@ -109,9 +111,10 @@ class Sampler(nn.Module): sampling_info.need_min_p_sampling, ) if return_logprob: + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 34565c9ff..d9d77a9ae 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -36,7 +36,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): def run_decode( self, return_logprob=True, - top_logprobs_num=3, + top_logprobs_num=5, return_text=True, n=1, **sampling_params, @@ -58,8 +58,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): "logprob_start_len": 0, }, ) - print(json.dumps(response.json())) - print("=" * 100) + assert response.status_code == 200, "Request failed: " + response.text def test_default_values(self): self.run_decode() @@ -112,4 +111,4 @@ class TestBatchPenalizerE2E(unittest.TestCase): if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=3)