Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: YorkSu <york_su@qq.com>
@@ -235,6 +235,44 @@ Important Notes:
2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`, which provides an improved unified prompt; a sample launch command is shown below.
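
For example, a launch command using this template might look like the following (a sketch; the model path and `--tp` size are placeholders to adapt to your deployment):

```
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja
```
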
### Thinking Budget for DeepSeek R1

In SGLang, we can implement a thinking budget with a `CustomLogitProcessor`: once the thinking stage reaches the token budget, the processor forces a newline followed by the thinking-end token, so the model stops reasoning and produces its final answer.

Launch a server with the `--enable-custom-logit-processor` flag enabled:

```
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --reasoning-parser deepseek-r1 --enable-custom-logit-processor
```

Sample Request:

```python
import openai
from rich.pretty import pprint
from sglang.srt.sampling.custom_logit_processor import DeepSeekR1ThinkingBudgetLogitProcessor


client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*")
response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[
        {
            "role": "user",
            "content": "Question: Is Paris the Capital of France?",
        }
    ],
    max_tokens=1024,
    extra_body={
        "custom_logit_processor": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(),
        "custom_params": {
            "thinking_budget": 512,
        },
    },
)
pprint(response)
```
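
Since the launch command above enables `--reasoning-parser deepseek-r1`, the response should expose the reasoning separately (commonly as a `reasoning_content` field on the message); if so, the budget can be sanity-checked by counting its tokens. A sketch, assuming `response` from the request above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
reasoning = getattr(response.choices[0].message, "reasoning_content", None) or ""
# The processor lets thinking run up to the budget, then forces "\n</think>",
# so the count should not exceed the budget by more than a couple of tokens.
print(len(tokenizer.encode(reasoning, add_special_tokens=False)))
```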

## FAQ

**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**

@@ -319,3 +319,27 @@ response = requests.post(
)
print(response.json())
```

Send an OpenAI chat completion request:

```python
import openai
from sglang.utils import print_highlight

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.0,
    max_tokens=32,
    extra_body={
        # DeterministicLogitProcessor is assumed to be defined earlier in
        # this document (see the custom logit processor example above).
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
        "custom_params": {"token_id": 5},
    },
)

print_highlight(f"Response: {response}")
```

@@ -243,6 +243,8 @@ class CompletionRequest(BaseModel):
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    session_params: Optional[Dict] = None
    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
    custom_params: Optional[Dict] = None
    custom_logit_processor: Optional[str] = None

    # For PD disaggregation
    bootstrap_host: Optional[Union[List[str], str]] = None

@@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel):
    stream_reasoning: bool = True
    chat_template_kwargs: Optional[Dict] = None

    # Custom logit processor for advanced sampling control
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
    custom_params: Optional[Dict] = None

    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
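
These two fields ride along with the standard OpenAI-compatible payload, so they can also be set directly in a raw HTTP request. A sketch (assumes a local server started with `--enable-custom-logit-processor`):

```python
import requests

from sglang.srt.sampling.custom_logit_processor import (
    DeepSeekR1ThinkingBudgetLogitProcessor,
)

payload = {
    "model": "deepseek-ai/DeepSeek-R1",
    "messages": [{"role": "user", "content": "Is Paris the capital of France?"}],
    "max_tokens": 1024,
    # Serialized processor plus its parameters, as defined by the new fields.
    "custom_logit_processor": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(),
    "custom_params": {"thinking_budget": 512},
}
resp = requests.post("http://127.0.0.1:30000/v1/chat/completions", json=payload)
print(resp.json())
```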

@@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel):
            "ignore_eos": self.ignore_eos,
            "skip_special_tokens": self.skip_special_tokens,
            "logit_bias": self.logit_bias,
            "custom_params": self.custom_params,
        }

        if self.response_format and self.response_format.type == "json_schema":

@@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase):
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            custom_logit_processor=request.custom_logit_processor,
        )

        return adapted_request, request

@@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            custom_logit_processor=request.custom_logit_processor,
        )

        return adapted_request, request

@@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
            "custom_params": request.custom_params,
        }

        # Handle response_format constraints

@@ -1,12 +1,15 @@
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import dill
import orjson
import torch

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req


@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):

@@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
        ), f"{custom_param_list=}"
        logits[..., disallowed_token_ids] = -float("inf")
        return logits


class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that controls the length of thinking."""

    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(self, logits, custom_param_list: list[dict[str, Any]]):
        if custom_param_list is None or not custom_param_list:
            return logits
        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue

            thinking_budget: int | None = param_dict.get("thinking_budget")

            # Skip if thinking_budget is unset, not an integer, or negative
            if (
                thinking_budget is None
                or not isinstance(thinking_budget, int)
                or thinking_budget < 0
            ):
                continue
            req: Req = param_dict.get("__req__")
            cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]

            # Check if out of thinking stage
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Find the index of the thinking start token
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)

            # Count the number of tokens after the thinking start token
            num_tokens_after_start = len(cur_ids) - start_index - 1

            if num_tokens_after_start < thinking_budget:
                continue

            # Ensure a newline token precedes the thinking end token
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                logits[i, :] = -float("inf")
                logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
                continue

            # Assign highest probability to the thinking end token
            logits[i, :] = -float("inf")
            logits[i, self.THINKING_END_TOKEN_ID] = 0.0

        return logits


class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for Qwen3 models."""

    THINKING_START_TOKEN_ID: int = 151667
    THINKING_END_TOKEN_ID: int = 151668
    NEW_LINE_TOKEN_ID: int = 198


class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for DeepSeek-R1 models."""

    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201

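Supporting another reasoning model is just a matter of subclassing `ThinkingBudgetLogitProcessor` with that model's special token ids. A minimal sketch (the class name and ids below are hypothetical placeholders, not a real tokenizer's values):

```python
class MyModelThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """Thinking budget for a hypothetical model; replace the placeholder ids
    with the tokenizer's actual thinking-start, thinking-end, and newline ids."""

    THINKING_START_TOKEN_ID: int = 1000  # placeholder
    THINKING_END_TOKEN_ID: int = 1001  # placeholder
    NEW_LINE_TOKEN_ID: int = 12  # placeholder
```
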
@@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test
"""

import json
import random
import re
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import numpy as np
import openai
import requests

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS

@@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase):
        self.assertTrue(isinstance(response[1]["index"], int))


class TestOpenAIServerCustomLogitProcessor(CustomTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=["--enable-custom-logit-processor"],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls) -> None:
        kill_process_tree(cls.process.pid)

    def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None:
        """
        Test custom logit processor with custom params.

        If target_token_id is None, the custom logit processor won't be passed in.
        """

        class DeterministicLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that changes the logits to always sample the given token id."""

            CUSTOM_PARAM_KEY = "token_id"

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)

                for i, param_dict in enumerate(custom_param_list):
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0

                return logits

        extra_body = {}

        if target_token_id is not None:
            extra_body["custom_logit_processor"] = (
                DeterministicLogitProcessor().to_str()
            )
            extra_body["custom_params"] = {
                "token_id": target_token_id,
            }

        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        max_tokens = 200

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": "Question: Is Paris the Capital of France?",
                },
            ],
            temperature=0.0,
            max_tokens=max_tokens,
            extra_body=extra_body,
        )

        if target_token_id is not None:
            target_text = self.tokenizer.decode([target_token_id] * max_tokens)
            self.assertTrue(
                target_text == response.choices[0].message.content,
                f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}",
            )

    def test_custom_logit_processor(self) -> None:
        """Test custom logit processor with a single request."""
        self.run_custom_logit_processor(target_token_id=5)

    def test_custom_logit_processor_batch_mixed(self) -> None:
        """Test a batch of requests mixed of requests with and without custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(executor.map(self.run_custom_logit_processor, target_token_ids))


class TestOpenAIV1Score(CustomTestCase):
    @classmethod
    def setUpClass(cls):