[oai serving chat] Add argument --sampling-defaults and fix ChatCompletionRequest defaults (#11304)
@@ -17,7 +17,7 @@ import logging
 import math
 import os
 from enum import Enum, IntEnum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union

 import torch
 from transformers import PretrainedConfig
@@ -90,6 +90,7 @@ class ModelConfig:
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        sampling_defaults: str = "openai",
     ) -> None:
         # Parse args
         self.model_path = model_path
@@ -98,6 +99,7 @@ class ModelConfig:
         self.modelopt_quant = modelopt_quant
         self.is_draft_model = is_draft_model
         self.model_impl = model_impl
+        self.sampling_defaults = sampling_defaults

         # Get hf config
         self._maybe_pull_model_tokenizer_from_remote()
@@ -214,6 +216,7 @@ class ModelConfig:
             modelopt_quant=server_args.modelopt_quant,
             hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             model_impl=server_args.model_impl,
+            sampling_defaults=server_args.sampling_defaults,
             **kwargs,
         )

@@ -659,6 +662,38 @@ class ModelConfig:
             eos_ids = eos_ids | generation_eos_ids
         return eos_ids

+    def get_default_sampling_params(self) -> dict[str, Any]:
+        """
+        Get default sampling parameters from the model's generation config.
+
+        This method returns non-default sampling parameters from the model's
+        generation_config.json when sampling_defaults is set to "model".
+
+        Returns:
+            A dictionary containing the non-default sampling parameters.
+        """
+        if self.sampling_defaults != "model":
+            return {}
+
+        if self.hf_generation_config is None:
+            return {}
+
+        config = self.hf_generation_config.to_dict()
+
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+
+        default_sampling_params = {
+            p: config.get(p) for p in available_params if config.get(p) is not None
+        }
+
+        return default_sampling_params
+
     def _maybe_pull_model_tokenizer_from_remote(self) -> None:
         """
         Pull the model config files to a temporary
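The hunks above add a sampling_defaults knob to ModelConfig and a helper that reads the model's recommended sampling values from generation_config.json. To make the helper's effect concrete, here is a small self-contained Python sketch of the same filtering logic; the generation_config contents below are hypothetical and not taken from any particular model:

    # Hypothetical generation_config.json contents; only keys listed in available_params
    # and not None survive the filtering done by get_default_sampling_params().
    generation_config = {"temperature": 0.6, "top_p": 0.95, "do_sample": True, "eos_token_id": 2}

    available_params = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]
    defaults = {
        p: generation_config.get(p)
        for p in available_params
        if generation_config.get(p) is not None
    }
    print(defaults)  # {'temperature': 0.6, 'top_p': 0.95}

With --sampling-defaults openai, or when the model ships no generation config, the method returns an empty dict and the OpenAI-style defaults defined in the protocol module take over.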
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Pydantic models for OpenAI API protocol"""

+import logging
 import time
 import uuid
 from dataclasses import dataclass
@@ -37,6 +38,10 @@ from pydantic import (
 )
 from typing_extensions import Literal

+from sglang.utils import convert_json_schema_to_str
+
+logger = logging.getLogger(__name__)
+
 DEFAULT_MODEL_NAME = "default"


@@ -445,8 +450,8 @@ class ChatCompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: bool = False
     stream_options: Optional[StreamOptions] = None
-    temperature: float = 0.7
-    top_p: float = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     user: Optional[str] = None
     tools: Optional[List[Tool]] = Field(default=None, examples=[None])
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
@@ -461,6 +466,47 @@ class ChatCompletionRequest(BaseModel):
         "Currently only supported for OpenAI models in the harmony path, i.e GPT-OSS models.",
     )

+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    min_tokens: int = 0
+    regex: Optional[str] = None
+    ebnf: Optional[str] = None
+    repetition_penalty: Optional[float] = None
+    stop_token_ids: Optional[List[int]] = None
+    no_stop_trim: bool = False
+    ignore_eos: bool = False
+    continue_final_message: bool = False
+    skip_special_tokens: bool = True
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None
+    separate_reasoning: bool = True
+    stream_reasoning: bool = True
+    chat_template_kwargs: Optional[Dict] = None
+
+    # For request id
+    rid: Optional[Union[List[str], str]] = None
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
+    # Cache salt for request caching
+    cache_salt: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None
+
+    # For PD disaggregation
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
+
+    # OpenAI/SGLang default sampling parameters
+    _DEFAULT_SAMPLING_PARAMS = {
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+        "repetition_penalty": 1.0,
+    }
+
     @model_validator(mode="before")
     @classmethod
     def set_tool_choice_default(cls, values):
@@ -531,37 +577,81 @@ class ChatCompletionRequest(BaseModel):

         return values

-    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
-    top_k: int = -1
-    min_p: float = 0.0
-    min_tokens: int = 0
-    regex: Optional[str] = None
-    ebnf: Optional[str] = None
-    repetition_penalty: float = 1.0
-    stop_token_ids: Optional[List[int]] = None
-    no_stop_trim: bool = False
-    ignore_eos: bool = False
-    continue_final_message: bool = False
-    skip_special_tokens: bool = True
-    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    session_params: Optional[Dict] = None
-    separate_reasoning: bool = True
-    stream_reasoning: bool = True
-    chat_template_kwargs: Optional[Dict] = None
-
-    # For request id
-    rid: Optional[Union[List[str], str]] = None
-    # Extra key for classifying the request (e.g. cache_salt)
-    extra_key: Optional[Union[List[str], str]] = None
-    # Cache salt for request caching
-    cache_salt: Optional[Union[List[str], str]] = None
-    # Priority for the request
-    priority: Optional[int] = None
-
-    # For PD disaggregation
-    bootstrap_host: Optional[Union[List[str], str]] = None
-    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
-    bootstrap_room: Optional[Union[List[int], int]] = None
+    def to_sampling_params(
+        self,
+        stop: List[str],
+        model_generation_config: Dict[str, Any],
+        tool_call_constraint: Optional[Any] = None,
+    ) -> Dict[str, Any]:
+        """
+        Convert request to sampling parameters.
+        Priority: user value > model generation_config > OpenAI defaults
+        """
+
+        def get_param(param_name: str):
+            value = getattr(self, param_name)
+            if value is None:
+                return model_generation_config.get(
+                    param_name, self._DEFAULT_SAMPLING_PARAMS[param_name]
+                )
+            return value
+
+        sampling_params = {
+            "temperature": get_param("temperature"),
+            "max_new_tokens": self.max_tokens or self.max_completion_tokens,
+            "min_new_tokens": self.min_tokens,
+            "stop": stop,
+            "stop_token_ids": self.stop_token_ids,
+            "top_p": get_param("top_p"),
+            "top_k": get_param("top_k"),
+            "min_p": get_param("min_p"),
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "repetition_penalty": get_param("repetition_penalty"),
+            "regex": self.regex,
+            "ebnf": self.ebnf,
+            "n": self.n,
+            "no_stop_trim": self.no_stop_trim,
+            "ignore_eos": self.ignore_eos,
+            "skip_special_tokens": self.skip_special_tokens,
+            "logit_bias": self.logit_bias,
+        }
+
+        if self.response_format and self.response_format.type == "json_schema":
+            sampling_params["json_schema"] = convert_json_schema_to_str(
+                self.response_format.json_schema.schema_
+            )
+        elif self.response_format and self.response_format.type == "json_object":
+            sampling_params["json_schema"] = '{"type": "object"}'
+        elif self.response_format and self.response_format.type == "structural_tag":
+            sampling_params["structural_tag"] = convert_json_schema_to_str(
+                self.response_format.model_dump(by_alias=True)
+            )
+
+        # Check if there are already existing output constraints
+        has_existing_constraints = (
+            sampling_params.get("regex")
+            or sampling_params.get("ebnf")
+            or sampling_params.get("structural_tag")
+            or sampling_params.get("json_schema")
+        )
+
+        if tool_call_constraint and has_existing_constraints:
+            logger.warning("Constrained decoding is not compatible with tool calls.")
+        elif tool_call_constraint:
+            constraint_type, constraint_value = tool_call_constraint
+            if constraint_type == "structural_tag":
+                sampling_params[constraint_type] = convert_json_schema_to_str(
+                    constraint_value.model_dump(by_alias=True)
+                )
+            elif constraint_type == "json_schema":
+                sampling_params[constraint_type] = convert_json_schema_to_str(
+                    constraint_value
+                )
+            else:
+                sampling_params[constraint_type] = constraint_value
+
+        return sampling_params


 class ChatMessage(BaseModel):
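The precedence rule in to_sampling_params can be read off get_param: an explicit user value always wins, an omitted value falls back to the model's generation config, and only then to the hard-coded OpenAI-style defaults. A minimal stand-alone sketch of that rule, with hypothetical values:

    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
        "repetition_penalty": 1.0,
    }

    def resolve(user_value, param_name, model_generation_config):
        # Mirrors get_param: user value > model generation_config > OpenAI defaults.
        if user_value is None:
            return model_generation_config.get(param_name, _DEFAULT_SAMPLING_PARAMS[param_name])
        return user_value

    model_cfg = {"temperature": 0.6}                # hypothetical generation_config recommendation
    print(resolve(None, "temperature", model_cfg))  # 0.6, the model default fills the gap
    print(resolve(0.2, "temperature", model_cfg))   # 0.2, an explicit user value wins
    print(resolve(None, "top_p", model_cfg))        # 1.0, OpenAI-style fallback

Changing temperature and top_p to Optional[...] = None is what makes this possible: with the old field defaults (0.7 and 1.0) there was no way to tell an omitted value from an explicitly chosen one.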
@@ -44,7 +44,6 @@ from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.parser.conversation import generate_chat_conv
 from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
 from sglang.srt.parser.reasoning_parser import ReasoningParser
-from sglang.utils import convert_json_schema_to_str

 if TYPE_CHECKING:
     from sglang.srt.managers.template_manager import TemplateManager
@@ -66,6 +65,15 @@ class OpenAIServingChat(OpenAIServingBase):
         self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
         self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser

+        # Get default sampling parameters from model's generation config
+        self.default_sampling_params = (
+            self.tokenizer_manager.model_config.get_default_sampling_params()
+        )
+        if self.default_sampling_params:
+            logger.info(
+                f"Using default chat sampling params from model generation config: {self.default_sampling_params}",
+            )
+
     def _request_id_prefix(self) -> str:
         return "chatcmpl-"

@@ -137,10 +145,10 @@ class OpenAIServingChat(OpenAIServingBase):
         processed_messages = self._process_messages(request, is_multimodal)

         # Build sampling parameters
-        sampling_params = self._build_sampling_params(
-            request,
-            processed_messages.stop,
-            processed_messages.tool_call_constraint,
+        sampling_params = request.to_sampling_params(
+            stop=processed_messages.stop,
+            model_generation_config=self.default_sampling_params,
+            tool_call_constraint=processed_messages.tool_call_constraint,
         )

         # Handle single vs multiple requests
@@ -410,72 +418,6 @@ class OpenAIServingChat(OpenAIServingBase):
             stop=stop,
         )

-    def _build_sampling_params(
-        self,
-        request: ChatCompletionRequest,
-        stop: List[str],
-        tool_call_constraint: Optional[Any],
-    ) -> Dict[str, Any]:
-        """Build sampling parameters for the request"""
-
-        sampling_params = {
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens or request.max_completion_tokens,
-            "min_new_tokens": request.min_tokens,
-            "stop": stop,
-            "stop_token_ids": request.stop_token_ids,
-            "top_p": request.top_p,
-            "top_k": request.top_k,
-            "min_p": request.min_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "repetition_penalty": request.repetition_penalty,
-            "regex": request.regex,
-            "ebnf": request.ebnf,
-            "n": request.n,
-            "no_stop_trim": request.no_stop_trim,
-            "ignore_eos": request.ignore_eos,
-            "skip_special_tokens": request.skip_special_tokens,
-            "logit_bias": request.logit_bias,
-        }
-
-        if request.response_format and request.response_format.type == "json_schema":
-            sampling_params["json_schema"] = convert_json_schema_to_str(
-                request.response_format.json_schema.schema_
-            )
-        elif request.response_format and request.response_format.type == "json_object":
-            sampling_params["json_schema"] = '{"type": "object"}'
-        elif (
-            request.response_format and request.response_format.type == "structural_tag"
-        ):
-            sampling_params["structural_tag"] = convert_json_schema_to_str(
-                request.response_format.model_dump(by_alias=True)
-            )
-
-        # Check if there are already existing output constraints
-        has_existing_constraints = (
-            sampling_params.get("regex")
-            or sampling_params.get("ebnf")
-            or sampling_params.get("structural_tag")
-            or sampling_params.get("json_schema")
-        )
-
-        if tool_call_constraint and has_existing_constraints:
-            logger.warning("Constrained decoding is not compatible with tool calls.")
-        elif tool_call_constraint:
-            constraint_type, constraint_value = tool_call_constraint
-            if constraint_type == "structural_tag":
-                sampling_params[constraint_type] = convert_json_schema_to_str(
-                    constraint_value.model_dump(by_alias=True)
-                )
-            elif constraint_type == "json_schema":
-                sampling_params[constraint_type] = convert_json_schema_to_str(
-                    constraint_value
-                )
-            else:
-                sampling_params[constraint_type] = constraint_value
-        return sampling_params
-
     async def _handle_streaming_request(
         self,
         adapted_request: GenerateReqInput,
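For orientation, a condensed sketch of how the serving layer now wires the two pieces together; names follow the diff, but this class is a stand-in and omits everything else OpenAIServingChat does:

    class ServingChatSketch:
        def __init__(self, model_config):
            # Computed once at startup: {} under --sampling-defaults openai, otherwise the
            # non-None recommendations pulled from the model's generation_config.json.
            self.default_sampling_params = model_config.get_default_sampling_params()

        def build_params(self, request, stop, tool_call_constraint=None):
            # Each chat request resolves its sampling parameters against the cached defaults.
            return request.to_sampling_params(
                stop=stop,
                model_generation_config=self.default_sampling_params,
                tool_call_constraint=tool_call_constraint,
            )

The old per-request _build_sampling_params helper, removed in the last hunk above, becomes unnecessary because the same logic now lives on ChatCompletionRequest itself.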
@@ -252,6 +252,7 @@ class ServerArgs:
     reasoning_parser: Optional[str] = None
     tool_call_parser: Optional[str] = None
     tool_server: Optional[str] = None
+    sampling_defaults: str = "model"

     # Data parallelism
     dp_size: int = 1
@@ -1872,6 +1873,16 @@ class ServerArgs:
             default=ServerArgs.tool_call_parser,
             help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
         )
+        parser.add_argument(
+            "--sampling-defaults",
+            type=str,
+            choices=["openai", "model"],
+            default=ServerArgs.sampling_defaults,
+            help="Where to get default sampling parameters. "
+            "'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). "
+            "'model' uses the model's generation_config.json to get the recommended "
+            "sampling parameters if available. Default is 'model'.",
+        )
         parser.add_argument(
             "--tool-server",
             type=str,
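The flag itself is plain argparse. A self-contained sketch of its parsing behavior; the parser below is a stand-in, not the real ServerArgs argument builder:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sampling-defaults",
        type=str,
        choices=["openai", "model"],
        default="model",
        help="Where to get default sampling parameters.",
    )

    print(parser.parse_args([]).sampling_defaults)                                 # 'model', the new default
    print(parser.parse_args(["--sampling-defaults", "openai"]).sampling_defaults)  # 'openai'
    # parser.parse_args(["--sampling-defaults", "foo"]) exits with an "invalid choice" error.

Note that ServerArgs defaults to "model" while the ModelConfig constructor defaults to "openai"; from_server_args always passes the server value through, so the conservative ModelConfig default should only matter when a ModelConfig is built directly without server args.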
@@ -150,10 +150,26 @@ class TestChatCompletionRequest(unittest.TestCase):
         self.assertEqual(len(request.messages), 1)
         self.assertEqual(request.messages[0].role, "user")
         self.assertEqual(request.messages[0].content, "Hello")
-        self.assertEqual(request.temperature, 0.7)  # default
+        self.assertEqual(request.temperature, None)  # default
         self.assertFalse(request.stream)  # default
         self.assertEqual(request.tool_choice, "none")  # default when no tools

+    def test_sampling_param_build(self):
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[{"role": "user", "content": "Hi"}],
+            temperature=0.8,
+            max_tokens=150,
+            min_tokens=5,
+            top_p=0.9,
+            stop=["</s>"],
+        )
+        params = req.to_sampling_params(["</s>"], {}, None)
+        self.assertEqual(params["temperature"], 0.8)
+        self.assertEqual(params["max_new_tokens"], 150)
+        self.assertEqual(params["min_new_tokens"], 5)
+        self.assertEqual(params["stop"], ["</s>"])
+
     def test_chat_completion_tool_choice_validation(self):
         """Test tool choice validation logic"""
         messages = [{"role": "user", "content": "Hello"}]
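As a possible extension of the new test (not part of this diff), the same method can be exercised with a non-empty model_generation_config to cover the "model defaults" path; ChatCompletionRequest is already imported in that test module:

    def test_sampling_param_model_defaults(self):
        req = ChatCompletionRequest(model="x", messages=[{"role": "user", "content": "Hi"}])
        params = req.to_sampling_params(["</s>"], {"temperature": 0.6}, None)
        self.assertEqual(params["temperature"], 0.6)  # filled from the model's generation config
        self.assertEqual(params["top_p"], 1.0)        # falls back to the OpenAI-style default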
@@ -177,28 +177,6 @@ class ServingChatTestCase(unittest.TestCase):
         self.assertNotIn("CUSTOM_STOP", result2.stop)
         self.assertEqual(conv_ins.stop_str, initial_stop_str)

-    # ------------- sampling-params -------------
-    def test_sampling_param_build(self):
-        req = ChatCompletionRequest(
-            model="x",
-            messages=[{"role": "user", "content": "Hi"}],
-            temperature=0.8,
-            max_tokens=150,
-            min_tokens=5,
-            top_p=0.9,
-            stop=["</s>"],
-        )
-        with patch.object(
-            self.chat,
-            "_process_messages",
-            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
-        ):
-            params = self.chat._build_sampling_params(req, ["</s>"], None)
-            self.assertEqual(params["temperature"], 0.8)
-            self.assertEqual(params["max_new_tokens"], 150)
-            self.assertEqual(params["min_new_tokens"], 5)
-            self.assertEqual(params["stop"], ["</s>"])
-
     async def test_unstreamed_tool_args_completion(self):
         """Test that remaining tool call arguments are sent when generation finishes."""