[Enhancement] Custom Logit Processor Improvement (#2998)

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
Hongpeng Guo
2025-01-20 02:00:35 -08:00
committed by GitHub
parent 2584f6d944
commit 583697cd71
6 changed files with 79 additions and 28 deletions

View File

@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
"""
import json
import random
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import numpy as np
import requests
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
self.assertTrue(all(x is not None for x in logprobs))
def run_custom_logit_processor(self, target_token_id: int):
"""Test custom logit processor with custom params."""
def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
"""Test custom logit processor with custom params.
If target_token_id is None, the custom logit processor won't be passed in.
"""
custom_params = {"token_id": target_token_id}
@@ -285,8 +290,12 @@ class TestSRTEndpoint(unittest.TestCase):
# Custom json data with custom logit processor and params.
custom_json = base_json.copy()
custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
custom_json["sampling_params"]["custom_params"] = custom_params
# Only set the custom logit processor if target_token_id is not None.
if target_token_id is not None:
custom_json["custom_logit_processor"] = (
DeterministicLogitProcessor().to_str()
)
custom_json["sampling_params"]["custom_params"] = custom_params
custom_response = requests.post(
self.base_url + "/generate",
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
sampled_tokens = [x[1] for x in output_token_logprobs]
# The logit processor should always sample the given token as the logits is deterministic.
self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
if target_token_id is not None:
self.assertTrue(
all(x == custom_params["token_id"] for x in sampled_tokens),
# Print the detailed test case info if the test fails.
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
# Temporarily skipped due to buggy implementation
return
self.run_custom_logit_processor(target_token_id=5)
def test_custom_logit_processor_batch(self):
"""Test custom logit processor with a batch of requests."""
# Temporarily skipped due to buggy implementation
return
target_token_ids = list(range(32))
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
random.shuffle(target_token_ids)
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()