From dde8bb16fe9180bd1642bdb8d4f0aa283b120ee4 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sat, 5 Oct 2024 17:27:43 -0700 Subject: [PATCH] default sampling param should be deepcopied (#1581) --- python/sglang/lang/interpreter.py | 21 +++++++++++++------ python/sglang/lang/ir.py | 19 +++++++++++++---- python/sglang/srt/sampling/sampling_params.py | 4 +++- python/sglang/test/test_utils.py | 2 +- 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 2fc72c2db..31c39d76a 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -2,6 +2,7 @@ import asyncio import contextvars +import copy import multiprocessing import queue import threading @@ -652,7 +653,19 @@ class StreamExecutor: self._init_var_event(e) def _resolve_sampling_params(self, sampling_params): - clone = None + """ + Construct sampling param based on default + override values + + The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args) + , and `sampling_params` contains the override values from sgl.gen(). + + Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`. + It also extends the stop tokens based on the chat template. + """ + + # deepcopy is required because the dict has lists inside + clone = copy.deepcopy(self.default_sampling_para) + for item in [ "max_new_tokens", "stop", @@ -674,20 +687,16 @@ class StreamExecutor: ]: value = getattr(sampling_params, item, None) if value is not None: - if clone is None: - clone = self.default_sampling_para.clone() setattr(clone, item, value) if self.chat_template.stop_str: - if not clone: - clone = self.default_sampling_para.clone() if clone.stop == (): clone.stop = [] elif isinstance(clone.stop, str): clone.stop = [clone.stop] clone.stop += self.chat_template.stop_str - return clone or self.default_sampling_para + return clone def __del__(self): self.end() diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 99a3e8e68..2e81d4bcd 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -150,8 +150,8 @@ class SglFunction: self, *args, max_new_tokens: int = 128, - stop: Union[str, List[str]] = [], - stop_token_ids: Optional[List[int]] = [], + stop: Union[str, List[str]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -169,6 +169,12 @@ class SglFunction: ): from sglang.lang.interpreter import run_program + # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/ + if stop is None: + stop = [] + if stop_token_ids is None: + stop_token_ids = [] + default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, @@ -193,8 +199,8 @@ class SglFunction: batch_kwargs, *, max_new_tokens: int = 128, - stop: Union[str, List[str]] = (), - stop_token_ids: Optional[List[int]] = [], + stop: Union[str, List[str]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -212,6 +218,11 @@ class SglFunction: ): from sglang.lang.interpreter import run_program_batch + if stop is None: + stop = [] + if stop_token_ids is None: + stop_token_ids = [] + assert isinstance(batch_kwargs, (list, tuple)) if len(batch_kwargs) == 0: return [] diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 8111757d8..2c251bac4 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -26,7 +26,7 @@ class SamplingParams: max_new_tokens: int = 128, min_new_tokens: int = 0, stop: Optional[Union[str, List[str]]] = None, - stop_token_ids: Optional[List[int]] = [], + stop_token_ids: Optional[List[int]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -41,6 +41,8 @@ class SamplingParams: n: int = 1, json_schema: Optional[str] = None, ) -> None: + if stop_token_ids is None: + stop_token_ids = [] self.temperature = temperature self.top_p = top_p self.top_k = top_k diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 7d844c9bc..8fb20c6eb 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None def call_generate_outlines( - prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None + prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None ): assert url is not None