[Minor, Performance] Use torch.argmax for greedy sampling (#1589)
@@ -43,7 +43,10 @@ class Sampler(nn.Module):
             torch.isnan(probs), torch.full_like(probs, 1e-10), probs
         )
 
-        if global_server_args_dict["sampling_backend"] == "flashinfer":
+        if sampling_info.top_ks.max().item() <= 1:
+            # Use torch.argmax if all requests use greedy sampling
+            batch_next_token_ids = torch.argmax(probs, -1)
+        elif global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
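
A rough standalone sketch of the fast path added above (not the sglang code itself): when every request in the batch is greedy (top_k == 1), sampling from probs always returns the most probable token, so a plain argmax gives the same result without launching the flashinfer/torch sampling kernels. The tensor shapes and the stand-in top_ks tensor below are illustrative assumptions, not the real sampling_info object.

import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)  # (batch, vocab) probabilities
top_ks = torch.ones(4, dtype=torch.int64)             # every request uses top_k == 1

if top_ks.max().item() <= 1:
    # Greedy fast path: equivalent to top_k == 1 sampling, done with a single argmax
    batch_next_token_ids = torch.argmax(probs, -1)
else:
    # Otherwise fall back to an actual sampling kernel
    batch_next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

print(batch_next_token_ids.shape)  # torch.Size([4])
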
@@ -27,11 +27,11 @@ class TestBenchServing(unittest.TestCase):
             model=DEFAULT_MODEL_NAME_FOR_TEST,
             num_prompts=200,
             request_rate=float("inf"),
+            other_server_args=["--max-running-requests", "10"],
             dataset_name="sharegpt",
             random_input_len=None,
             random_output_len=None,
             disable_stream=True,
-            other_server_args=["--max-running-requests", "10"],
         )
 
         if is_in_ci():
@@ -1,6 +1,9 @@
+import json
 import unittest
 from types import SimpleNamespace
 
+import requests
+
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -39,6 +42,32 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
         metrics = run_eval(args)
         assert metrics["score"] >= 0.65
 
+    def test_greedy(self):
+        response_single = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                },
+            },
+        ).json()
+        response_batch = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": ["The capital of France is"] * 10,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                },
+            },
+        ).json()
+        text = response_single["text"]
+        print(text)
+        for i in range(10):
+            assert response_batch[i]["text"] == text
+
 
 if __name__ == "__main__":
     unittest.main()
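
The new test_greedy case sends the same prompt once and as a batch of ten with temperature 0 (greedy decoding) and requires every completion to match, which exercises the argmax path end to end. As a usage note, roughly the same check can be run by hand against an already-running server; the address below is an assumed default local endpoint, not something defined in this commit.

import requests

base_url = "http://127.0.0.1:30000"  # assumed local server address; adjust as needed

payload = {
    "text": ["The capital of France is"] * 4,
    "sampling_params": {"temperature": 0, "max_new_tokens": 16},
}
outputs = requests.post(base_url + "/generate", json=payload).json()

# With temperature 0, every request should take the greedy argmax path and
# return an identical completion.
unique_texts = {item["text"] for item in outputs}
assert len(unique_texts) == 1, unique_texts
print(unique_texts.pop())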