default sampling param should be deepcopied (#1581)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import copy
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
@@ -652,7 +653,19 @@ class StreamExecutor:
|
||||
self._init_var_event(e)
|
||||
|
||||
def _resolve_sampling_params(self, sampling_params):
|
||||
clone = None
|
||||
"""
|
||||
Construct sampling param based on default + override values
|
||||
|
||||
The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args),
|
||||
and `sampling_params` contains the override values from sgl.gen().
|
||||
|
||||
Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
|
||||
It also extends the stop tokens based on the chat template.
|
||||
"""
|
||||
|
||||
# deepcopy is required because the default params may contain lists (e.g. `stop`, which is extended in place below)
|
||||
clone = copy.deepcopy(self.default_sampling_para)
|
||||
|
||||
for item in [
|
||||
"max_new_tokens",
|
||||
"stop",
|
||||
@@ -674,20 +687,16 @@ class StreamExecutor:
|
||||
]:
|
||||
value = getattr(sampling_params, item, None)
|
||||
if value is not None:
|
||||
if clone is None:
|
||||
clone = self.default_sampling_para.clone()
|
||||
setattr(clone, item, value)
|
||||
|
||||
if self.chat_template.stop_str:
|
||||
if not clone:
|
||||
clone = self.default_sampling_para.clone()
|
||||
if clone.stop == ():
|
||||
clone.stop = []
|
||||
elif isinstance(clone.stop, str):
|
||||
clone.stop = [clone.stop]
|
||||
clone.stop += self.chat_template.stop_str
|
||||
|
||||
return clone or self.default_sampling_para
|
||||
return clone
|
||||
|
||||
def __del__(self):
|
||||
self.end()
|
||||
|
||||
Reference in New Issue
Block a user