feat: add thinking_budget (#6089)
@@ -64,6 +64,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
 | ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
 | skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
 | custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
+| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |
 
 ## Examples
 
@@ -296,3 +297,29 @@ response = requests.post(
 )
 print(response.json())
 ```
+
+### Thinking Budget
+
+Launch a server with `--reasoning-parser`.
+
+```bash
+python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
+```
+
+Send a request:
+
+```python
+import requests
+
+response = requests.post(
+    "http://localhost:30000/generate",
+    json={
+        "text": "9.11 and 9.8, which is greater?",
+        "sampling_params": {
+            "temperature": 0.3,
+            "max_new_tokens": 256,
+            "thinking_budget": 20,
+        },
+    },
+)
+print(response.json())
+```
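Note: the commit also threads `thinking_budget` through the OpenAI-compatible layer (see the `v1_generate_request` / `v1_chat_generate_request` hunks below), so the budget can be set on `/v1/chat/completions` as well. A minimal sketch, mirroring the request shape used by the new test at the end of this commit, against a server launched as above:

```python
import requests

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
        "temperature": 0,
        "separate_reasoning": True,
        "chat_template_kwargs": {"enable_thinking": True},
        "thinking_budget": 20,
    },
)
print(response.json()["choices"][0]["message"]["reasoning_content"])
```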
@@ -1145,7 +1145,9 @@ class ModelRunner:
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
             )
-
+        sampling_info = forward_batch.sampling_info
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
 
         # Sample the next tokens
@@ -1156,6 +1158,8 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids
 
     @property
@@ -529,6 +529,7 @@ def v1_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens,
             "min_new_tokens": request.min_tokens,
+            "thinking_budget": request.thinking_budget,
             "stop": request.stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
@@ -1101,6 +1102,7 @@ def v1_chat_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
+            "thinking_budget": request.thinking_budget,
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -350,6 +351,13 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description="The maximum number of reasoning tokens that can be generated for a request. "
+        "This setting does not affect the thinking process of models. "
+        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
+        "the reasoning content will be truncated and the final response content will be generated immediately.",
+    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
+        text = text.replace(self.think_start_token, "")
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
         normal_text = current_text[end_idx + len(self.think_end_token) :]
 
         return StreamingParseResult(
-            normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
+            normal_text=normal_text, reasoning_text=reasoning_text
         )
 
         # Continue with reasoning content
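The `.strip()` / `.rstrip()` removals keep `reasoning_content` byte-identical to what the model emitted. That plausibly matters for this feature because the new test at the end of this commit re-encodes the reasoning text and asserts an exact token count, and whitespace encodes to tokens of its own. A quick check (assumes the Qwen3 tokenizer can be downloaded):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
text = "The answer is 9.8.\n\n"
# Trailing whitespace encodes to its own token(s), so stripping changes the count.
print(len(tok.encode(text)), len(tok.encode(text.strip())))
```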
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
+    # Use thinking_budget to truncate thinking
+    num_thinking_tokens: Optional[torch.Tensor] = None
+    think_end_ids: Optional[torch.Tensor] = None
+    thinking_budgets: Optional[torch.Tensor] = None
+
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int
+    vocab_size: int = 0
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@ class SamplingBatchInfo:
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
+        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
+            think_end_ids = torch.tensor(
+                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
+                device, non_blocking=True
+            )
+            thinking_budgets = torch.tensor(
+                [r.sampling_params.thinking_budget or -1 for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+        else:
+            think_end_ids = None
+            num_thinking_tokens = None
+            thinking_budgets = None
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
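One subtlety in the batch construction above: `r.sampling_params.thinking_budget or -1` relies on Python truthiness, so it maps both `None` and an explicit `0` to the "no budget" sentinel `-1`. A tiny illustration:

```python
# `or -1` coerces every falsy value to the sentinel; note that a budget of 0
# is falsy too, so it is treated the same as "no budget" here.
print([x or -1 for x in (None, 0, 20)])  # [-1, -1, 20]
```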
@@ -132,6 +152,9 @@ class SamplingBatchInfo:
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            think_end_ids=think_end_ids,
+            num_thinking_tokens=num_thinking_tokens,
+            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -146,6 +169,35 @@ class SamplingBatchInfo:
     def __len__(self):
         return len(self.temperatures)
 
+    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
+        has_budget = self.thinking_budgets > 0
+        if not has_budget.any():
+            return
+        torch.where(
+            has_budget,
+            self.num_thinking_tokens + 1,
+            self.num_thinking_tokens,
+            out=self.num_thinking_tokens,
+        )
+        should_stop = has_budget & (
+            self.num_thinking_tokens - 1 > self.thinking_budgets
+        )
+        next_token_logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
+        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
+        if len(batch_indices) > 0:
+            end_token_indices = self.think_end_ids[batch_indices]
+            next_token_logits[batch_indices, end_token_indices] = 0.0
+
+    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
+        if not torch.any(self.thinking_budgets > 0):
+            return
+        torch.where(
+            next_token_ids == self.think_end_ids,
+            torch.tensor(-1, device=self.thinking_budgets.device),
+            self.thinking_budgets,
+            out=self.thinking_budgets,
+        )
+
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
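Taken together with the `ModelRunner.sample` hunks above, the control flow is: `apply_thinking_budgets` edits the raw logits before sampling, and `update_thinking_budgets` disarms budgets once the sampled ids are known. Below is a minimal, self-contained sketch of the same mechanism, with toy names and a tiny vocabulary of my own and greedy `argmax` standing in for SGLang's sampler:

```python
import torch

# Toy setup (hypothetical, not SGLang's tensors): 2 requests, an 8-token
# vocabulary, and token id 5 playing the role of `</think>`.
VOCAB_SIZE, THINK_END = 8, 5
thinking_budgets = torch.tensor([2, -1])    # request 0 budgeted, request 1 not
num_thinking_tokens = torch.tensor([0, 0])  # thinking tokens counted so far
think_end_ids = torch.tensor([THINK_END, THINK_END])

def apply_thinking_budgets(next_token_logits: torch.Tensor) -> None:
    # Mirrors SamplingBatchInfo.apply_thinking_budgets: count one thinking
    # token per budgeted request, then force `</think>` once over budget.
    has_budget = thinking_budgets > 0
    if not has_budget.any():
        return
    torch.where(has_budget, num_thinking_tokens + 1, num_thinking_tokens,
                out=num_thinking_tokens)
    should_stop = has_budget & (num_thinking_tokens - 1 > thinking_budgets)
    next_token_logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
    rows = torch.nonzero(should_stop, as_tuple=True)[0]
    next_token_logits[rows, think_end_ids[rows]] = 0.0  # the only finite logit

def update_thinking_budgets(next_token_ids: torch.Tensor) -> None:
    # Mirrors SamplingBatchInfo.update_thinking_budgets: once `</think>` is
    # emitted, reset the budget to -1 so the request is no longer constrained.
    torch.where(next_token_ids == think_end_ids,
                torch.full_like(thinking_budgets, -1), thinking_budgets,
                out=thinking_budgets)

for step in range(5):
    logits = torch.zeros(2, VOCAB_SIZE)  # stand-in for a model forward pass
    apply_thinking_budgets(logits)
    next_ids = logits.argmax(dim=-1)     # greedy stand-in for the sampler
    update_thinking_budgets(next_ids)
    print(step, next_ids.tolist())       # request 0 is forced to 5 (`</think>`)
```

Once a request is over budget, every logit is `-inf` except `</think>` at `0.0`, so it is the only token that can be sampled; emitting it disarms the budget and generation continues unconstrained.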
@@ -30,6 +30,7 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
+        thinking_budget: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -57,6 +58,7 @@ class SamplingParams:
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
+        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -61,6 +61,7 @@ suites = {
         TestFile("test_radix_attention.py", 167),
         TestFile("test_reasoning_content.py", 89),
         TestFile("test_enable_thinking.py", 70),
+        TestFile("test_thinking_budget.py", 60),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_release_memory_occupation.py", 44),
         TestFile("test_request_length_validation.py", 31),
test/srt/test_thinking_budget.py | 95 (new file)
@@ -0,0 +1,95 @@
+"""
+Usage:
+python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
+python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
+"""
+
+import unittest
+
+import requests
+from transformers import AutoTokenizer
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestThinkingBudget(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen3-8B"
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-1234"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--reasoning-parser",
+                "qwen3",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_chat_completion_with_thinking_budget_20(self):
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
+                ],
+                "temperature": 0,
+                "separate_reasoning": True,
+                "chat_template_kwargs": {"enable_thinking": True},
+                "thinking_budget": 20,
+            },
+        )
+        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
+        data = response.json()
+        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
+        tokens = self.tokenizer.encode(reasoning_content)
+        self.assertEqual(
+            len(tokens),
+            20,
+            f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
+        )
+
+    def test_chat_completion_with_thinking_budget_200(self):
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
+                ],
+                "temperature": 0,
+                "separate_reasoning": True,
+                "chat_template_kwargs": {"enable_thinking": True},
+                "thinking_budget": 200,
+            },
+        )
+        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
+        data = response.json()
+        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
+        tokens = self.tokenizer.encode(reasoning_content)
+        self.assertEqual(
+            len(tokens),
+            200,
+            f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()