Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: YorkSu <york_su@qq.com>
This commit is contained in:
@@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||
@@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase):
|
||||
self.assertTrue(isinstance(response[1]["index"], int))
|
||||
|
||||
|
||||
class TestOpenAIServerCustomLogitProcessor(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=["--enable-custom-logit-processor"],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
Test custom logit processor with custom params.
|
||||
|
||||
If target_token_id is None, the custom logit processor won't be passed in.
|
||||
"""
|
||||
|
||||
class DeterministicLogitProcessor(CustomLogitProcessor):
|
||||
"""A dummy logit processor that changes the logits to always sample the given token id."""
|
||||
|
||||
CUSTOM_PARAM_KEY = "token_id"
|
||||
|
||||
def __call__(self, logits, custom_param_list):
|
||||
assert logits.shape[0] == len(custom_param_list)
|
||||
|
||||
for i, param_dict in enumerate(custom_param_list):
|
||||
# Mask all other tokens
|
||||
logits[i, :] = -float("inf")
|
||||
# Assign highest probability to the specified token
|
||||
logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0
|
||||
|
||||
return logits
|
||||
|
||||
extra_body = {}
|
||||
|
||||
if target_token_id is not None:
|
||||
extra_body["custom_logit_processor"] = (
|
||||
DeterministicLogitProcessor().to_str()
|
||||
)
|
||||
extra_body["custom_params"] = {
|
||||
"token_id": target_token_id,
|
||||
}
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
max_tokens = 200
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Question: Is Paris the Capital of France?",
|
||||
},
|
||||
],
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
if target_token_id is not None:
|
||||
target_text = self.tokenizer.decode([target_token_id] * max_tokens)
|
||||
self.assertTrue(
|
||||
target_text == response.choices[0].message.content,
|
||||
f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}",
|
||||
)
|
||||
|
||||
def test_custom_logit_processor(self) -> None:
|
||||
"""Test custom logit processor with a single request."""
|
||||
self.run_custom_logit_processor(target_token_id=5)
|
||||
|
||||
def test_custom_logit_processor_batch_mixed(self) -> None:
|
||||
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||
target_token_ids = list(range(32)) + [None] * 16
|
||||
random.shuffle(target_token_ids)
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
|
||||
class TestOpenAIV1Score(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user