default sampling param should be deepcopied (#1581)
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
import contextvars
|
||||||
|
import copy
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
@@ -652,7 +653,19 @@ class StreamExecutor:
|
|||||||
self._init_var_event(e)
|
self._init_var_event(e)
|
||||||
|
|
||||||
def _resolve_sampling_params(self, sampling_params):
|
def _resolve_sampling_params(self, sampling_params):
|
||||||
clone = None
|
"""
|
||||||
|
Construct sampling param based on default + override values
|
||||||
|
|
||||||
|
The default sampling values are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
|
||||||
|
, and `sampling_params` contains the override values from sgl.gen().
|
||||||
|
|
||||||
|
Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
|
||||||
|
It also extends the stop tokens based on the chat template.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# deepcopy is required because the params object holds lists (e.g. `stop`) that are extended in place below
|
||||||
|
clone = copy.deepcopy(self.default_sampling_para)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
"max_new_tokens",
|
"max_new_tokens",
|
||||||
"stop",
|
"stop",
|
||||||
@@ -674,20 +687,16 @@ class StreamExecutor:
|
|||||||
]:
|
]:
|
||||||
value = getattr(sampling_params, item, None)
|
value = getattr(sampling_params, item, None)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if clone is None:
|
|
||||||
clone = self.default_sampling_para.clone()
|
|
||||||
setattr(clone, item, value)
|
setattr(clone, item, value)
|
||||||
|
|
||||||
if self.chat_template.stop_str:
|
if self.chat_template.stop_str:
|
||||||
if not clone:
|
|
||||||
clone = self.default_sampling_para.clone()
|
|
||||||
if clone.stop == ():
|
if clone.stop == ():
|
||||||
clone.stop = []
|
clone.stop = []
|
||||||
elif isinstance(clone.stop, str):
|
elif isinstance(clone.stop, str):
|
||||||
clone.stop = [clone.stop]
|
clone.stop = [clone.stop]
|
||||||
clone.stop += self.chat_template.stop_str
|
clone.stop += self.chat_template.stop_str
|
||||||
|
|
||||||
return clone or self.default_sampling_para
|
return clone
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.end()
|
self.end()
|
||||||
|
|||||||
@@ -150,8 +150,8 @@ class SglFunction:
|
|||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
stop: Union[str, List[str]] = [],
|
stop: Union[str, List[str]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = [],
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -169,6 +169,12 @@ class SglFunction:
|
|||||||
):
|
):
|
||||||
from sglang.lang.interpreter import run_program
|
from sglang.lang.interpreter import run_program
|
||||||
|
|
||||||
|
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
|
||||||
|
if stop is None:
|
||||||
|
stop = []
|
||||||
|
if stop_token_ids is None:
|
||||||
|
stop_token_ids = []
|
||||||
|
|
||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
@@ -193,8 +199,8 @@ class SglFunction:
|
|||||||
batch_kwargs,
|
batch_kwargs,
|
||||||
*,
|
*,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
stop: Union[str, List[str]] = (),
|
stop: Union[str, List[str]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = [],
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -212,6 +218,11 @@ class SglFunction:
|
|||||||
):
|
):
|
||||||
from sglang.lang.interpreter import run_program_batch
|
from sglang.lang.interpreter import run_program_batch
|
||||||
|
|
||||||
|
if stop is None:
|
||||||
|
stop = []
|
||||||
|
if stop_token_ids is None:
|
||||||
|
stop_token_ids = []
|
||||||
|
|
||||||
assert isinstance(batch_kwargs, (list, tuple))
|
assert isinstance(batch_kwargs, (list, tuple))
|
||||||
if len(batch_kwargs) == 0:
|
if len(batch_kwargs) == 0:
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class SamplingParams:
|
|||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
min_new_tokens: int = 0,
|
min_new_tokens: int = 0,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = [],
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -41,6 +41,8 @@ class SamplingParams:
|
|||||||
n: int = 1,
|
n: int = 1,
|
||||||
json_schema: Optional[str] = None,
|
json_schema: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if stop_token_ids is None:
|
||||||
|
stop_token_ids = []
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
|
|||||||
|
|
||||||
|
|
||||||
def call_generate_outlines(
|
def call_generate_outlines(
|
||||||
prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
|
prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
|
||||||
):
|
):
|
||||||
assert url is not None
|
assert url is not None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user