Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: YorkSu <york_su@qq.com>
This commit is contained in:
@@ -235,6 +235,44 @@ Important Notes:
|
|||||||
2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. It provides an improved unified prompt.
|
2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. It provides an improved unified prompt.
|
||||||
|
|
||||||
|
|
||||||
|
### Thinking Budget for DeepSeek R1
|
||||||
|
|
||||||
|
In SGLang, we can implement thinking budget with `CustomLogitProcessor`.
|
||||||
|
|
||||||
|
Launch a server with `--enable-custom-logit-processor` flag on.
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --reasoning-parser deepseek-r1 --enable-custom-logit-processor
|
||||||
|
```
|
||||||
|
|
||||||
|
Sample Request:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
from rich.pretty import pprint
|
||||||
|
from sglang.srt.sampling.custom_logit_processor import DeepSeekR1ThinkingBudgetLogitProcessor
|
||||||
|
|
||||||
|
|
||||||
|
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*")
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="deepseek-ai/DeepSeek-R1",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Question: Is Paris the Capital of France?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens=1024,
|
||||||
|
extra_body={
|
||||||
|
"custom_logit_processor": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(),
|
||||||
|
"custom_params": {
|
||||||
|
"thinking_budget": 512,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
pprint(response)
|
||||||
|
```
|
||||||
|
|
||||||
## FAQ
|
## FAQ
|
||||||
|
|
||||||
**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**
|
**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**
|
||||||
|
|||||||
@@ -319,3 +319,27 @@ response = requests.post(
|
|||||||
)
|
)
|
||||||
print(response.json())
|
print(response.json())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Send an OpenAI chat completion request:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
from sglang.utils import print_highlight
|
||||||
|
|
||||||
|
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
|
],
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=32,
|
||||||
|
extra_body={
|
||||||
|
"custom_logit_processor": DeterministicLogitProcessor().to_str(),
|
||||||
|
"custom_params": {"token_id": 5},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
print_highlight(f"Response: {response}")
|
||||||
|
```
|
||||||
|
|||||||
@@ -243,6 +243,8 @@ class CompletionRequest(BaseModel):
|
|||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
session_params: Optional[Dict] = None
|
session_params: Optional[Dict] = None
|
||||||
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
|
||||||
|
custom_params: Optional[Dict] = None
|
||||||
|
custom_logit_processor: Optional[str] = None
|
||||||
|
|
||||||
# For PD disaggregation
|
# For PD disaggregation
|
||||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||||
@@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
stream_reasoning: bool = True
|
stream_reasoning: bool = True
|
||||||
chat_template_kwargs: Optional[Dict] = None
|
chat_template_kwargs: Optional[Dict] = None
|
||||||
|
|
||||||
|
# Custom logit processor for advanced sampling control
|
||||||
|
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||||
|
custom_params: Optional[Dict] = None
|
||||||
|
|
||||||
# For request id
|
# For request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Extra key for classifying the request (e.g. cache_salt)
|
# Extra key for classifying the request (e.g. cache_salt)
|
||||||
@@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
"ignore_eos": self.ignore_eos,
|
"ignore_eos": self.ignore_eos,
|
||||||
"skip_special_tokens": self.skip_special_tokens,
|
"skip_special_tokens": self.skip_special_tokens,
|
||||||
"logit_bias": self.logit_bias,
|
"logit_bias": self.logit_bias,
|
||||||
|
"custom_params": self.custom_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.response_format and self.response_format.type == "json_schema":
|
if self.response_format and self.response_format.type == "json_schema":
|
||||||
|
|||||||
@@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
extra_key=self._compute_extra_key(request),
|
extra_key=self._compute_extra_key(request),
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
custom_labels=custom_labels,
|
custom_labels=custom_labels,
|
||||||
|
custom_logit_processor=request.custom_logit_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, request
|
return adapted_request, request
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
extra_key=self._compute_extra_key(request),
|
extra_key=self._compute_extra_key(request),
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
custom_labels=custom_labels,
|
custom_labels=custom_labels,
|
||||||
|
custom_logit_processor=request.custom_logit_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, request
|
return adapted_request, request
|
||||||
@@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
"ignore_eos": request.ignore_eos,
|
"ignore_eos": request.ignore_eos,
|
||||||
"skip_special_tokens": request.skip_special_tokens,
|
"skip_special_tokens": request.skip_special_tokens,
|
||||||
"logit_bias": request.logit_bias,
|
"logit_bias": request.logit_bias,
|
||||||
|
"custom_params": request.custom_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle response_format constraints
|
# Handle response_format constraints
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
import dill
|
import dill
|
||||||
import orjson
|
import orjson
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def _cache_from_str(json_str: str):
|
def _cache_from_str(json_str: str):
|
||||||
@@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
|
|||||||
), f"{custom_param_list=}"
|
), f"{custom_param_list=}"
|
||||||
logits[..., disallowed_token_ids] = -float("inf")
|
logits[..., disallowed_token_ids] = -float("inf")
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
|
||||||
|
"""A logit processor that controls the length of thinking."""
|
||||||
|
|
||||||
|
THINKING_START_TOKEN_ID: int
|
||||||
|
THINKING_END_TOKEN_ID: int
|
||||||
|
NEW_LINE_TOKEN_ID: int
|
||||||
|
|
||||||
|
def __call__(self, logits, custom_param_list: list[dict[str, Any]]):
|
||||||
|
if custom_param_list is None or not custom_param_list:
|
||||||
|
return logits
|
||||||
|
for i, param_dict in enumerate(custom_param_list):
|
||||||
|
if param_dict is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
thinking_budget: int | None = param_dict.get("thinking_budget")
|
||||||
|
|
||||||
|
# Skip if thinking_budget is unset, or not an integer, or negative
|
||||||
|
if (
|
||||||
|
thinking_budget is None
|
||||||
|
or not isinstance(thinking_budget, int)
|
||||||
|
or thinking_budget < 0
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
req: Req = param_dict.get("__req__")
|
||||||
|
cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]
|
||||||
|
|
||||||
|
# Check if out of thinking stage
|
||||||
|
if (
|
||||||
|
self.THINKING_START_TOKEN_ID not in cur_ids
|
||||||
|
or self.THINKING_END_TOKEN_ID in cur_ids
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find the index of the thinking start token
|
||||||
|
start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
|
||||||
|
|
||||||
|
# Count the number of tokens after the thinking start token
|
||||||
|
num_tokens_after_start = len(cur_ids) - start_index - 1
|
||||||
|
|
||||||
|
if num_tokens_after_start < thinking_budget:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Ensure new line token before thinking end token
|
||||||
|
if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
|
||||||
|
logits[i, :] = -float("inf")
|
||||||
|
logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Assign highest probability to the thinking end token
|
||||||
|
logits[i, :] = -float("inf")
|
||||||
|
logits[i, self.THINKING_END_TOKEN_ID] = 0.0
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
|
||||||
|
"""A logit processor that controls the length of thinking for Qwen3 models."""
|
||||||
|
|
||||||
|
THINKING_START_TOKEN_ID: int = 151667
|
||||||
|
THINKING_END_TOKEN_ID: int = 151668
|
||||||
|
NEW_LINE_TOKEN_ID: int = 198
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
|
||||||
|
"""A logit processor that controls the length of thinking for DeepSeek-R1 models."""
|
||||||
|
|
||||||
|
THINKING_START_TOKEN_ID: int = 128798
|
||||||
|
THINKING_END_TOKEN_ID: int = 128799
|
||||||
|
NEW_LINE_TOKEN_ID: int = 201
|
||||||
|
|||||||
@@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import unittest
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||||
@@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase):
|
|||||||
self.assertTrue(isinstance(response[1]["index"], int))
|
self.assertTrue(isinstance(response[1]["index"], int))
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIServerCustomLogitProcessor(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
api_key=cls.api_key,
|
||||||
|
other_args=["--enable-custom-logit-processor"],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
cls.tokenizer = get_tokenizer(cls.model)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None:
|
||||||
|
"""
|
||||||
|
Test custom logit processor with custom params.
|
||||||
|
|
||||||
|
If target_token_id is None, the custom logit processor won't be passed in.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class DeterministicLogitProcessor(CustomLogitProcessor):
|
||||||
|
"""A dummy logit processor that changes the logits to always sample the given token id."""
|
||||||
|
|
||||||
|
CUSTOM_PARAM_KEY = "token_id"
|
||||||
|
|
||||||
|
def __call__(self, logits, custom_param_list):
|
||||||
|
assert logits.shape[0] == len(custom_param_list)
|
||||||
|
|
||||||
|
for i, param_dict in enumerate(custom_param_list):
|
||||||
|
# Mask all other tokens
|
||||||
|
logits[i, :] = -float("inf")
|
||||||
|
# Assign highest probability to the specified token
|
||||||
|
logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
extra_body = {}
|
||||||
|
|
||||||
|
if target_token_id is not None:
|
||||||
|
extra_body["custom_logit_processor"] = (
|
||||||
|
DeterministicLogitProcessor().to_str()
|
||||||
|
)
|
||||||
|
extra_body["custom_params"] = {
|
||||||
|
"token_id": target_token_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
max_tokens = 200
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Question: Is Paris the Capital of France?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if target_token_id is not None:
|
||||||
|
target_text = self.tokenizer.decode([target_token_id] * max_tokens)
|
||||||
|
self.assertTrue(
|
||||||
|
target_text == response.choices[0].message.content,
|
||||||
|
f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_custom_logit_processor(self) -> None:
|
||||||
|
"""Test custom logit processor with a single request."""
|
||||||
|
self.run_custom_logit_processor(target_token_id=5)
|
||||||
|
|
||||||
|
def test_custom_logit_processor_batch_mixed(self) -> None:
|
||||||
|
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||||
|
target_token_ids = list(range(32)) + [None] * 16
|
||||||
|
random.shuffle(target_token_ids)
|
||||||
|
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||||
|
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIV1Score(CustomTestCase):
|
class TestOpenAIV1Score(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user