[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -5,10 +5,12 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
@@ -24,7 +26,10 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=("--enable-custom-logit-processor",),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -248,6 +253,62 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
self.assertTrue(all(x is not None for x in logprobs))
|
||||
|
||||
def run_custom_logit_processor(self, target_token_id: int):
|
||||
"""Test custom logit processor with custom params."""
|
||||
|
||||
custom_params = {"token_id": target_token_id}
|
||||
|
||||
class DeterministicLogitProcessor(CustomLogitProcessor):
|
||||
"""A dummy logit processor that changes the logits to always
|
||||
sample the given token id.
|
||||
"""
|
||||
|
||||
def __call__(self, logits, custom_param_list):
|
||||
assert logits.shape[0] == len(custom_param_list)
|
||||
key = "token_id"
|
||||
|
||||
for i, param_dict in enumerate(custom_param_list):
|
||||
# Mask all other tokens
|
||||
logits[i, :] = -float("inf")
|
||||
# Assign highest probability to the specified token
|
||||
logits[i, param_dict[key]] = 0.0
|
||||
return logits
|
||||
|
||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||
|
||||
# Base case json data to be posted to the server.
|
||||
base_json = {
|
||||
"text": prompts,
|
||||
"sampling_params": {"temperature": 0.0},
|
||||
"return_logprob": True,
|
||||
}
|
||||
|
||||
# Custom json data with custom logit processor and params.
|
||||
custom_json = base_json.copy()
|
||||
custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
|
||||
custom_json["sampling_params"]["custom_params"] = custom_params
|
||||
|
||||
custom_response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json=custom_json,
|
||||
).json()
|
||||
|
||||
output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
|
||||
sampled_tokens = [x[1] for x in output_token_logprobs]
|
||||
|
||||
# The logit processor should always sample the given token as the logits is deterministic.
|
||||
self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
|
||||
|
||||
def test_custom_logit_processor(self):
|
||||
"""Test custom logit processor with a single request."""
|
||||
self.run_custom_logit_processor(target_token_id=5)
|
||||
|
||||
def test_custom_logit_processor_batch(self):
|
||||
"""Test custom logit processor with a batch of requests."""
|
||||
target_token_ids = list(range(32))
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
def test_get_server_info(self):
|
||||
response = requests.get(self.base_url + "/get_server_info")
|
||||
response_json = response.json()
|
||||
|
||||
Reference in New Issue
Block a user