[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
self.assertTrue(all(x is not None for x in logprobs))
|
||||
|
||||
def run_custom_logit_processor(self, target_token_id: int):
|
||||
"""Test custom logit processor with custom params."""
|
||||
def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
|
||||
"""Test custom logit processor with custom params.
|
||||
|
||||
If target_token_id is None, the custom logit processor won't be passed in.
|
||||
"""
|
||||
|
||||
custom_params = {"token_id": target_token_id}
|
||||
|
||||
@@ -285,8 +290,12 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
# Custom json data with custom logit processor and params.
|
||||
custom_json = base_json.copy()
|
||||
custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
|
||||
custom_json["sampling_params"]["custom_params"] = custom_params
|
||||
# Only set the custom logit processor if target_token_id is not None.
|
||||
if target_token_id is not None:
|
||||
custom_json["custom_logit_processor"] = (
|
||||
DeterministicLogitProcessor().to_str()
|
||||
)
|
||||
custom_json["sampling_params"]["custom_params"] = custom_params
|
||||
|
||||
custom_response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
sampled_tokens = [x[1] for x in output_token_logprobs]
|
||||
|
||||
# The logit processor should always sample the given token as the logits is deterministic.
|
||||
self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
|
||||
if target_token_id is not None:
|
||||
self.assertTrue(
|
||||
all(x == custom_params["token_id"] for x in sampled_tokens),
|
||||
# Print the detailed test case info if the test fails.
|
||||
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
||||
)
|
||||
|
||||
def test_custom_logit_processor(self):
|
||||
"""Test custom logit processor with a single request."""
|
||||
# Temporarily skipped due to buggy implementation
|
||||
return
|
||||
self.run_custom_logit_processor(target_token_id=5)
|
||||
|
||||
def test_custom_logit_processor_batch(self):
|
||||
"""Test custom logit processor with a batch of requests."""
|
||||
# Temporarily skipped due to buggy implementation
|
||||
return
|
||||
target_token_ids = list(range(32))
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
def test_custom_logit_processor_batch_mixed(self):
|
||||
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||
target_token_ids = list(range(32)) + [None] * 16
|
||||
random.shuffle(target_token_ids)
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
def test_get_server_info(self):
|
||||
response = requests.get(self.base_url + "/get_server_info")
|
||||
response_json = response.json()
|
||||
|
||||
Reference in New Issue
Block a user