[Feature] Add sampler custom logits processor (#2396)

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
Hongpeng Guo
2025-01-19 14:46:53 -08:00
committed by GitHub
parent 3bcf5ecea7
commit e403d23757
12 changed files with 302 additions and 4 deletions

View File

@@ -5,10 +5,12 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
import json
import unittest
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import requests
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
@@ -24,7 +26,10 @@ class TestSRTEndpoint(unittest.TestCase):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=("--enable-custom-logit-processor",),
)
@classmethod
@@ -248,6 +253,62 @@ class TestSRTEndpoint(unittest.TestCase):
self.assertTrue(all(x is not None for x in logprobs))
def run_custom_logit_processor(self, target_token_id: int):
    """Check that a custom logit processor forces a chosen token.

    Posts one ``/generate`` request that carries a serialized
    ``DeterministicLogitProcessor`` together with ``custom_params`` naming
    ``target_token_id``, then asserts that the response contains output
    tokens and that every one of them equals ``target_token_id``.

    Args:
        target_token_id: Token id that every sampled output token is
            expected to equal.
    """
    custom_params = {"token_id": target_token_id}

    class DeterministicLogitProcessor(CustomLogitProcessor):
        """A dummy logit processor that changes the logits to always
        sample the given token id.
        """

        def __call__(self, logits, custom_param_list):
            # One parameter dict is expected per row of ``logits``.
            assert logits.shape[0] == len(custom_param_list)
            key = "token_id"

            for i, param_dict in enumerate(custom_param_list):
                # Mask all other tokens
                logits[i, :] = -float("inf")
                # Assign highest probability to the specified token
                logits[i, param_dict[key]] = 0.0
            return logits

    prompts = "Question: Is Paris the Capital of France? Answer:"

    # Base case json data to be posted to the server.
    base_json = {
        "text": prompts,
        "sampling_params": {"temperature": 0.0},
        "return_logprob": True,
    }

    # Custom json data with custom logit processor and params.
    # The nested "sampling_params" dict is rebuilt rather than shared:
    # a shallow ``base_json.copy()`` would alias it, and adding
    # "custom_params" would then silently modify ``base_json`` as well.
    custom_json = {
        **base_json,
        "sampling_params": {
            **base_json["sampling_params"],
            "custom_params": custom_params,
        },
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
    }

    response = requests.post(
        self.base_url + "/generate",
        json=custom_json,
    )
    # Fail with the server's own message instead of a KeyError further down.
    self.assertEqual(response.status_code, 200, response.text)
    custom_response = response.json()

    output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
    # Each entry is indexed at position 1 to obtain the sampled token id.
    sampled_tokens = [x[1] for x in output_token_logprobs]

    # The logit processor makes the logits deterministic, so the given token
    # must be the only one ever sampled. Comparing sets also rejects an empty
    # output, which an ``all(...)`` check would accept vacuously.
    self.assertEqual(
        set(sampled_tokens),
        {target_token_id},
        f"unexpected sampled tokens: {sampled_tokens}",
    )
def test_custom_logit_processor(self):
    """Run the custom logit processor check once, for a single request.

    Delegates to ``run_custom_logit_processor`` with a fixed target token.
    """
    forced_token_id = 5
    self.run_custom_logit_processor(target_token_id=forced_token_id)
def test_custom_logit_processor_batch(self):
    """Run the custom logit processor check for 32 requests at once.

    Token ids 0..31 are each submitted to ``run_custom_logit_processor``
    from a thread pool with one worker per request.
    """
    batch_size = 32
    token_ids = [*range(batch_size)]

    with ThreadPoolExecutor(batch_size) as pool:
        outcomes = pool.map(self.run_custom_logit_processor, token_ids)
        # Drain the result iterator so that an exception raised inside any
        # worker (e.g. a failed assertion) is re-raised in this thread.
        for _ in outcomes:
            pass
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()