diff --git a/docs/backend/sampling_params.md b/docs/backend/sampling_params.md
index 9423ab06d..736d67e05 100644
--- a/docs/backend/sampling_params.md
+++ b/docs/backend/sampling_params.md
@@ -64,7 +64,6 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
 | ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
 | skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
 | custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
-| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |
 
 ## Examples
 
@@ -297,29 +296,3 @@ response = requests.post(
 )
 print(response.json())
 ```
-
-### Thinking Budget
-
-Launch a server with `--reasoning-parser`.
-
-```bash
-python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
-```
-
-Send a request:
-
-```python
-import requests
-response = requests.post(
-    "http://localhost:30000/generate",
-    json={
-        "text": "9.11 and 9.8, which is greater?",
-        "sampling_params": {
-            "temperature": 0.3,
-            "max_new_tokens": 256,
-            "thinking_budget": 20,
-        },
-    },
-)
-print(response.json())
-```
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 272154b33..deabf8265 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1145,9 +1145,7 @@ class ModelRunner:
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
             )
 
-        sampling_info = forward_batch.sampling_info
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
 
         # Sample the next tokens
@@ -1158,8 +1156,6 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids
 
     @property
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index d1db059ce..ba10f2951 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -529,7 +529,6 @@ def v1_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens,
             "min_new_tokens": request.min_tokens,
-            "thinking_budget": request.thinking_budget,
             "stop": request.stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
@@ -1102,7 +1101,6 @@ def v1_chat_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
-            "thinking_budget": request.thinking_budget,
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index d4cd845b3..c37442248 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -172,7 +172,6 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
-    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -351,13 +350,6 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
-    thinking_budget: Optional[int] = Field(
-        default=None,
-        description="The maximum number of reasoning tokens that can be generated for a request. "
-        "This setting of does not affect the thinking process of models. "
-        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
-        "the reasoning content will be truncated and the final response content will be generated immediately.",
-    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py
index bec184273..977e26d3e 100644
--- a/python/sglang/srt/reasoning_parser.py
+++ b/python/sglang/srt/reasoning_parser.py
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "")
+        text = text.replace(self.think_start_token, "").strip()
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
             normal_text = current_text[end_idx + len(self.think_end_token) :]
 
             return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text
+                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
             )
 
         # Continue with reasoning content
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index c1a6abdb6..66e6552c0 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -30,13 +30,8 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
-    # Use thinking_budget to truncate thinking
-    num_thinking_tokens: Optional[torch.Tensor] = None
-    think_end_ids: Optional[torch.Tensor] = None
-    thinking_budgets: Optional[torch.Tensor] = None
-
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int = 0
+    vocab_size: int
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -81,22 +76,7 @@ class SamplingBatchInfo:
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
-        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
-            think_end_ids = torch.tensor(
-                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
-                dtype=torch.int64,
-            ).to(device, non_blocking=True)
-            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
-                device, non_blocking=True
-            )
-            thinking_budgets = torch.tensor(
-                [r.sampling_params.thinking_budget or -1 for r in reqs],
-                dtype=torch.int64,
-            ).to(device, non_blocking=True)
-        else:
-            think_end_ids = None
-            num_thinking_tokens = None
-            thinking_budgets = None
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -152,9 +132,6 @@ class SamplingBatchInfo:
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
-            think_end_ids=think_end_ids,
-            num_thinking_tokens=num_thinking_tokens,
-            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -169,35 +146,6 @@ class SamplingBatchInfo:
     def __len__(self):
         return len(self.temperatures)
 
-    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
-        has_budget = self.thinking_budgets > 0
-        if not has_budget.any():
-            return
-        torch.where(
-            has_budget,
-            self.num_thinking_tokens + 1,
-            self.num_thinking_tokens,
-            out=self.num_thinking_tokens,
-        )
-        should_stop = has_budget & (
-            self.num_thinking_tokens - 1 > self.thinking_budgets
-        )
-        next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
-        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
-        if len(batch_indices) > 0:
-            end_token_indices = self.think_end_ids[batch_indices]
-            next_token_logits[batch_indices, end_token_indices] = 0.0
-
-    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
-        if not torch.any(self.thinking_budgets > 0):
-            return
-        torch.where(
-            next_token_ids == self.think_end_ids,
-            torch.tensor(-1, device=self.thinking_budgets.device),
-            self.thinking_budgets,
-            out=self.thinking_budgets,
-        )
-
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index 03c9d6202..7c77a204f 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -30,7 +30,6 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
-        thinking_budget: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -58,7 +57,6 @@ class SamplingParams:
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
-        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 82a9adac1..59018e343 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -61,7 +61,6 @@ suites = {
         TestFile("test_radix_attention.py", 167),
         TestFile("test_reasoning_content.py", 89),
         TestFile("test_enable_thinking.py", 70),
-        TestFile("test_thinking_budget.py", 60),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_release_memory_occupation.py", 44),
         TestFile("test_request_length_validation.py", 31),
diff --git a/test/srt/test_thinking_budget.py b/test/srt/test_thinking_budget.py
deleted file mode 100644
index 9d264c9c6..000000000
--- a/test/srt/test_thinking_budget.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""
-Usage:
-python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
-python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
-"""
-
-import unittest
-
-import requests
-from transformers import AutoTokenizer
-
-from sglang.srt.utils import kill_process_tree
-from sglang.test.test_utils import (
-    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-    DEFAULT_URL_FOR_TEST,
-    CustomTestCase,
-    popen_launch_server,
-)
-
-
-class TestThinkingBudget(CustomTestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = "Qwen/Qwen3-8B"
-        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.api_key = "sk-1234"
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            api_key=cls.api_key,
-            other_args=[
-                "--reasoning-parser",
-                "qwen3",
-            ],
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_chat_completion_with_thinking_budget_20(self):
-        response = requests.post(
-            f"{self.base_url}/v1/chat/completions",
-            headers={"Authorization": f"Bearer {self.api_key}"},
-            json={
-                "model": self.model,
-                "messages": [
-                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
-                ],
-                "temperature": 0,
-                "separate_reasoning": True,
-                "chat_template_kwargs": {"enable_thinking": True},
-                "thinking_budget": 20,
-            },
-        )
-        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
-        data = response.json()
-        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
-        tokens = self.tokenizer.encode(reasoning_content)
-        self.assertEqual(
-            len(tokens),
-            20,
-            f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
-        )
-
-    def test_chat_completion_with_thinking_budget_200(self):
-        response = requests.post(
-            f"{self.base_url}/v1/chat/completions",
-            headers={"Authorization": f"Bearer {self.api_key}"},
-            json={
-                "model": self.model,
-                "messages": [
-                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
-                ],
-                "temperature": 0,
-                "separate_reasoning": True,
-                "chat_template_kwargs": {"enable_thinking": True},
-                "thinking_budget": 200,
-            },
-        )
-        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
-        data = response.json()
-        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
-        tokens = self.tokenizer.encode(reasoning_content)
-        self.assertEqual(
-            len(tokens),
-            200,
-            f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
-        )
-
-
-if __name__ == "__main__":
-    unittest.main()