Support Thinking Budget (via custom_logit_processor for OpenAI API) [Fix #6572] (#11416)

Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: YorkSu <york_su@qq.com>
This commit is contained in:
ybyang
2025-10-21 16:27:56 +08:00
committed by GitHub
parent c1e1600373
commit dbb16bedd5
7 changed files with 239 additions and 1 deletions

View File

@@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test
"""
import json
import random
import re
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import numpy as np
import openai
import requests
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
@@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase):
self.assertTrue(isinstance(response[1]["index"], int))
class TestOpenAIServerCustomLogitProcessor(CustomTestCase):
    """Exercise ``custom_logit_processor`` through the OpenAI-compatible chat API.

    One server is launched for the whole class with
    ``--enable-custom-logit-processor``. Each test sends chat-completion
    requests that optionally carry a serialized logit processor and its
    per-request parameters via ``extra_body``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Load the tokenizer, then launch the server used by every test."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Load the tokenizer *before* spawning the server: unittest does not
        # run tearDownClass when setUpClass raises, so any failure after the
        # launch would leave the server process running.
        cls.tokenizer = get_tokenizer(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=["--enable-custom-logit-processor"],
        )
        # The OpenAI client below talks to the "/v1" prefix of the server URL.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls) -> None:
        """Terminate the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None:
        """Send one chat completion and verify the processor forced the output.

        Args:
            target_token_id: Token id the processor must force at every
                step. If ``None``, no processor is attached and the request
                is only required to complete; its content is not checked.
        """

        # NOTE(review): the class is defined locally rather than at module
        # level; presumably ``to_str()`` relies on this for serialization.
        # Confirm before hoisting it.
        class DeterministicLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that always samples the given token id."""

            CUSTOM_PARAM_KEY = "token_id"

            def __call__(self, logits, custom_param_list):
                # One parameter dict is expected per row of ``logits``.
                assert logits.shape[0] == len(custom_param_list)

                for i, param_dict in enumerate(custom_param_list):
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0
                return logits

        extra_body = {}
        if target_token_id is not None:
            extra_body["custom_logit_processor"] = (
                DeterministicLogitProcessor().to_str()
            )
            extra_body["custom_params"] = {
                "token_id": target_token_id,
            }

        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        max_tokens = 200

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": "Question: Is Paris the Capital of France?",
                },
            ],
            temperature=0.0,
            max_tokens=max_tokens,
            extra_body=extra_body,
        )

        if target_token_id is not None:
            # Every generated position was forced to the same token, so the
            # reply must decode to that token repeated ``max_tokens`` times.
            target_text = self.tokenizer.decode([target_token_id] * max_tokens)
            # assertEqual (rather than assertTrue on ``==``) also reports the
            # difference between the two strings when the check fails.
            self.assertEqual(
                target_text,
                response.choices[0].message.content,
                f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}",
            )

    def test_custom_logit_processor(self) -> None:
        """Test custom logit processor with a single request."""
        self.run_custom_logit_processor(target_token_id=5)

    def test_custom_logit_processor_batch_mixed(self) -> None:
        """Test concurrent requests, some with and some without a processor."""
        # 32 requests forcing token ids 0..31 plus 16 plain requests, shuffled
        # so that both kinds are interleaved when submitted.
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        # One worker per request so all of them are in flight at once.
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            # list() drains the iterator so worker exceptions are re-raised here.
            list(executor.map(self.run_custom_logit_processor, target_token_ids))
class TestOpenAIV1Score(CustomTestCase):
@classmethod
def setUpClass(cls):