From c98e84c21e4313d7d307425ca43e61753a53a9f7 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sun, 6 Oct 2024 13:15:05 -0700 Subject: [PATCH] [Minor, Performance] Use torch.argmax for greedy sampling (#1589) --- python/sglang/srt/layers/sampler.py | 5 +++- test/srt/test_bench_serving.py | 2 +- test/srt/test_pytorch_sampling_backend.py | 29 +++++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index ad7f0a1f3..b45ec080b 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -43,7 +43,10 @@ class Sampler(nn.Module): torch.isnan(probs), torch.full_like(probs, 1e-10), probs ) - if global_server_args_dict["sampling_backend"] == "flashinfer": + if sampling_info.top_ks.max().item() <= 1: + # Use torch.argmax if all requests use greedy sampling + batch_next_token_ids = torch.argmax(probs, -1) + elif global_server_args_dict["sampling_backend"] == "flashinfer": max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 056483487..6955d4917 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -27,11 +27,11 @@ class TestBenchServing(unittest.TestCase): model=DEFAULT_MODEL_NAME_FOR_TEST, num_prompts=200, request_rate=float("inf"), + other_server_args=["--max-running-requests", "10"], dataset_name="sharegpt", random_input_len=None, random_output_len=None, disable_stream=True, - other_server_args=["--max-running-requests", "10"], ) if is_in_ci(): diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py index ddd744149..5cd121235 100644 --- a/test/srt/test_pytorch_sampling_backend.py +++ b/test/srt/test_pytorch_sampling_backend.py @@ -1,6 +1,9 @@ +import json import unittest from types import SimpleNamespace +import requests + from sglang.srt.utils import kill_child_process from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -39,6 +42,32 @@ class TestPyTorchSamplingBackend(unittest.TestCase): metrics = run_eval(args) assert metrics["score"] >= 0.65 + def test_greedy(self): + response_single = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ).json() + response_batch = requests.post( + self.base_url + "/generate", + json={ + "text": ["The capital of France is"] * 10, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ).json() + text = response_single["text"] + print(text) + for i in range(10): + assert response_batch[i]["text"] == text + if __name__ == "__main__": unittest.main()