default sampling param should be deepcopied (#1581)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import copy
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
@@ -652,7 +653,19 @@ class StreamExecutor:
|
||||
self._init_var_event(e)
|
||||
|
||||
def _resolve_sampling_params(self, sampling_params):
|
||||
clone = None
|
||||
"""
|
||||
Construct sampling param based on default + override values
|
||||
|
||||
The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
|
||||
, and `sampling_params` contains the override values from sgl.gen().
|
||||
|
||||
Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
|
||||
It also extends the stop tokens based on the chat template.
|
||||
"""
|
||||
|
||||
# deepcopy is required because the dict has lists inside
|
||||
clone = copy.deepcopy(self.default_sampling_para)
|
||||
|
||||
for item in [
|
||||
"max_new_tokens",
|
||||
"stop",
|
||||
@@ -674,20 +687,16 @@ class StreamExecutor:
|
||||
]:
|
||||
value = getattr(sampling_params, item, None)
|
||||
if value is not None:
|
||||
if clone is None:
|
||||
clone = self.default_sampling_para.clone()
|
||||
setattr(clone, item, value)
|
||||
|
||||
if self.chat_template.stop_str:
|
||||
if not clone:
|
||||
clone = self.default_sampling_para.clone()
|
||||
if clone.stop == ():
|
||||
clone.stop = []
|
||||
elif isinstance(clone.stop, str):
|
||||
clone.stop = [clone.stop]
|
||||
clone.stop += self.chat_template.stop_str
|
||||
|
||||
return clone or self.default_sampling_para
|
||||
return clone
|
||||
|
||||
def __del__(self):
|
||||
self.end()
|
||||
|
||||
@@ -150,8 +150,8 @@ class SglFunction:
|
||||
self,
|
||||
*args,
|
||||
max_new_tokens: int = 128,
|
||||
stop: Union[str, List[str]] = [],
|
||||
stop_token_ids: Optional[List[int]] = [],
|
||||
stop: Union[str, List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -169,6 +169,12 @@ class SglFunction:
|
||||
):
|
||||
from sglang.lang.interpreter import run_program
|
||||
|
||||
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
|
||||
if stop is None:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
@@ -193,8 +199,8 @@ class SglFunction:
|
||||
batch_kwargs,
|
||||
*,
|
||||
max_new_tokens: int = 128,
|
||||
stop: Union[str, List[str]] = (),
|
||||
stop_token_ids: Optional[List[int]] = [],
|
||||
stop: Union[str, List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -212,6 +218,11 @@ class SglFunction:
|
||||
):
|
||||
from sglang.lang.interpreter import run_program_batch
|
||||
|
||||
if stop is None:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
return []
|
||||
|
||||
@@ -26,7 +26,7 @@ class SamplingParams:
|
||||
max_new_tokens: int = 128,
|
||||
min_new_tokens: int = 0,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = [],
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -41,6 +41,8 @@ class SamplingParams:
|
||||
n: int = 1,
|
||||
json_schema: Optional[str] = None,
|
||||
) -> None:
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
|
||||
@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
|
||||
|
||||
|
||||
def call_generate_outlines(
|
||||
prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
|
||||
prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
|
||||
):
|
||||
assert url is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user