diff --git a/docs/basic_usage/deepseek.md b/docs/basic_usage/deepseek.md index 7e5daa898..96f43ab0a 100644 --- a/docs/basic_usage/deepseek.md +++ b/docs/basic_usage/deepseek.md @@ -235,6 +235,44 @@ Important Notes: 2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. It provides an improved unified prompt. +### Thinking Budget for DeepSeek R1 + +In SGLang, a thinking budget can be implemented with a `CustomLogitProcessor`, which caps the number of tokens the model may spend in its thinking section. + +Launch a server with the `--enable-custom-logit-processor` flag enabled. + +``` +python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --reasoning-parser deepseek-r1 --enable-custom-logit-processor +``` + +Sample Request: + +```python +import openai +from rich.pretty import pprint +from sglang.srt.sampling.custom_logit_processor import DeepSeekR1ThinkingBudgetLogitProcessor + + +client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*") +response = client.chat.completions.create( + model="deepseek-ai/DeepSeek-R1", + messages=[ + { + "role": "user", + "content": "Question: Is Paris the Capital of France?", + } + ], + max_tokens=1024, + extra_body={ + "custom_logit_processor": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(), + "custom_params": { + "thinking_budget": 512, + }, + }, +) +pprint(response) +``` + ## FAQ **Q: Model loading is taking too long, and I'm encountering an NCCL timeout. 
What should I do?** diff --git a/docs/basic_usage/sampling_params.md b/docs/basic_usage/sampling_params.md index f6faf72d9..8b1035b81 100644 --- a/docs/basic_usage/sampling_params.md +++ b/docs/basic_usage/sampling_params.md @@ -319,3 +319,27 @@ response = requests.post( ) print(response.json()) ``` + +Send an OpenAI chat completion request (requires a server launched with `--enable-custom-logit-processor`, and a `DeterministicLogitProcessor` subclass of `CustomLogitProcessor` defined or imported in your script): + +```python +import openai +from sglang.utils import print_highlight + +client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None") + +response = client.chat.completions.create( + model="meta-llama/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.0, + max_tokens=32, + extra_body={ + "custom_logit_processor": DeterministicLogitProcessor().to_str(), + "custom_params": {"token_id": 5}, + }, +) + +print_highlight(f"Response: {response}") +``` diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 638e97978..5c1cdd4f5 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -243,6 +243,8 @@ class CompletionRequest(BaseModel): lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None session_params: Optional[Dict] = None response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None + custom_params: Optional[Dict] = None + custom_logit_processor: Optional[str] = None # For PD disaggregation bootstrap_host: Optional[Union[List[str], str]] = None @@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel): stream_reasoning: bool = True chat_template_kwargs: Optional[Dict] = None + # Custom logit processor for advanced sampling control + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None + custom_params: Optional[Dict] = None + # For request id rid: Optional[Union[List[str], str]] = None # Extra key for classifying the request (e.g. 
cache_salt) @@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel): "ignore_eos": self.ignore_eos, "skip_special_tokens": self.skip_special_tokens, "logit_bias": self.logit_bias, + "custom_params": self.custom_params, } if self.response_format and self.response_format.type == "json_schema": diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 9529d3dbd..ba43a21e1 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase): extra_key=self._compute_extra_key(request), priority=request.priority, custom_labels=custom_labels, + custom_logit_processor=request.custom_logit_processor, ) return adapted_request, request diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index b6c8d7432..1620dfa76 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase): extra_key=self._compute_extra_key(request), priority=request.priority, custom_labels=custom_labels, + custom_logit_processor=request.custom_logit_processor, ) return adapted_request, request @@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase): "ignore_eos": request.ignore_eos, "skip_special_tokens": request.skip_special_tokens, "logit_bias": request.logit_bias, + "custom_params": request.custom_params, } # Handle response_format constraints diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 80820c361..edd91afd6 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -1,12 +1,15 @@ import json from abc import ABC, abstractmethod 
from functools import lru_cache -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import dill import orjson import torch +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + @lru_cache(maxsize=None) def _cache_from_str(json_str: str): @@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor): ), f"{custom_param_list=}" logits[..., disallowed_token_ids] = -float("inf") return logits + + +class ThinkingBudgetLogitProcessor(CustomLogitProcessor): + """A logit processor that controls the length of thinking.""" + + THINKING_START_TOKEN_ID: int + THINKING_END_TOKEN_ID: int + NEW_LINE_TOKEN_ID: int + + def __call__(self, logits, custom_param_list: list[dict[str, Any]]): + if custom_param_list is None or not custom_param_list: + return logits + for i, param_dict in enumerate(custom_param_list): + if param_dict is None: + continue + + thinking_budget: int | None = param_dict.get("thinking_budget") + + # Skip if thinking_budget is unset, or not an integer, or negative + if ( + thinking_budget is None + or not isinstance(thinking_budget, int) + or thinking_budget < 0 + ): + continue + req: Req = param_dict.get("__req__") + cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids] + + # Check if out of thinking stage + if ( + self.THINKING_START_TOKEN_ID not in cur_ids + or self.THINKING_END_TOKEN_ID in cur_ids + ): + continue + + # Find the index of the thinking start token + start_index = cur_ids.index(self.THINKING_START_TOKEN_ID) + + # Count the number of tokens after the thinking start token + num_tokens_after_start = len(cur_ids) - start_index - 1 + + if num_tokens_after_start < thinking_budget: + continue + + # Ensure new line token before thinking end token + if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID: + logits[i, :] = -float("inf") + logits[i, self.NEW_LINE_TOKEN_ID] = 0.0 + continue + + # Assign highest probability to the thinking end 
token + logits[i, :] = -float("inf") + logits[i, self.THINKING_END_TOKEN_ID] = 0.0 + + return logits + + +class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor): + """A logit processor that controls the length of thinking for Qwen3 models.""" + + THINKING_START_TOKEN_ID: int = 151667 + THINKING_END_TOKEN_ID: int = 151668 + NEW_LINE_TOKEN_ID: int = 198 + + +class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor): + """A logit processor that controls the length of thinking for DeepSeek-R1 models.""" + + THINKING_START_TOKEN_ID: int = 128798 + THINKING_END_TOKEN_ID: int = 128799 + NEW_LINE_TOKEN_ID: int = 201 diff --git a/test/srt/openai_server/basic/test_openai_server.py b/test/srt/openai_server/basic/test_openai_server.py index 96251f2cd..c32ef9e90 100644 --- a/test/srt/openai_server/basic/test_openai_server.py +++ b/test/srt/openai_server/basic/test_openai_server.py @@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test """ import json +import random import re import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Optional import numpy as np import openai import requests +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.utils import kill_process_tree from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.runners import TEST_RERANK_QUERY_DOCS @@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase): self.assertTrue(isinstance(response[1]["index"], int)) +class TestOpenAIServerCustomLogitProcessor(CustomTestCase): + @classmethod + def setUpClass(cls) -> None: + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=["--enable-custom-logit-processor"], + ) + cls.base_url += "/v1" + 
cls.tokenizer = get_tokenizer(cls.model) + + @classmethod + def tearDownClass(cls) -> None: + kill_process_tree(cls.process.pid) + + def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None: + """ + Test custom logit processor with custom params. + + If target_token_id is None, the custom logit processor won't be passed in. + """ + + class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always sample the given token id.""" + + CUSTOM_PARAM_KEY = "token_id" + + def __call__(self, logits, custom_param_list): + assert logits.shape[0] == len(custom_param_list) + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0 + + return logits + + extra_body = {} + + if target_token_id is not None: + extra_body["custom_logit_processor"] = ( + DeterministicLogitProcessor().to_str() + ) + extra_body["custom_params"] = { + "token_id": target_token_id, + } + + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + max_tokens = 200 + + response = client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": "Question: Is Paris the Capital of France?", + }, + ], + temperature=0.0, + max_tokens=max_tokens, + extra_body=extra_body, + ) + + if target_token_id is not None: + target_text = self.tokenizer.decode([target_token_id] * max_tokens) + self.assertTrue( + target_text == response.choices[0].message.content, + f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}", + ) + + def test_custom_logit_processor(self) -> None: + """Test custom logit processor with a single request.""" + self.run_custom_logit_processor(target_token_id=5) + + def test_custom_logit_processor_batch_mixed(self) -> None: + """Test a batch of requests mixed of requests with and without custom logit 
processor.""" + target_token_ids = list(range(32)) + [None] * 16 + random.shuffle(target_token_ids) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list(executor.map(self.run_custom_logit_processor, target_token_ids)) + + class TestOpenAIV1Score(CustomTestCase): @classmethod def setUpClass(cls):