Fix openai speculative execution (#456)

@@ -196,12 +196,6 @@ class StreamExecutor:
         # For completion
         self.text_ = ""  # The full text
 
-        # For speculative execution
-        from sglang.backend.openai import OpenAI
-        if isinstance(backend, OpenAI):
-            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None
 
+        # For speculative execution
+        self.api_num_spec_tokens = api_num_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)
 
+            # TODO(ying): handle API speculative execution
+
         return exes
 
     def text(self):
@@ -399,7 +399,7 @@ class StreamExecutor:
 
         if (
             self.cur_role == "assistant"
-            and self.backend.api_num_spec_tokens is not None
+            and self.api_num_spec_tokens is not None
             and self.backend.is_chat_model
            and not prefix
         ):
@@ -435,71 +435,80 @@ class StreamExecutor:
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
 
+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.api_num_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name
 
         if not self.stream:
-            if self.backend.api_num_spec_tokens is None:
+            if self.api_num_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
                     self,
                     sampling_params=sampling_params,
                 )
-
-            elif self.backend.is_chat_model:
-                # spec on model with only chat interface
-                comp, meta_info = self.backend.generate(
-                    self,
-                    sampling_params=sampling_params,
-                    name=name,
-                )
-                return
-
-            else:  # spec on model with completion
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
                     )
+                    return
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                        return pos
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos = find_stop()
-                    if stop_pos == -1:
-                        stop_pos = min(
-                            sampling_params.max_new_tokens,
-                            len(self.speculated_text),
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)
 
         self.text_ += comp
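
Review note: `_spec_gen` implements a cache-and-consume pattern. `regen()` issues one completion call that over-generates (at least `api_num_spec_tokens` tokens, with stop strings disabled), and each later gen call slices its answer off the front of `self.speculated_text`, so several gen calls can be served by a single API request. Below is a minimal standalone sketch of the same idea, simplified to a single string stop; `SpecCache` and `generate_fn` are hypothetical names for illustration, not sglang API:

    # Hypothetical, simplified sketch of the cache-and-consume pattern in
    # _spec_gen. generate_fn stands in for one completion-API call.
    class SpecCache:
        def __init__(self, generate_fn, num_spec_tokens):
            self.generate_fn = generate_fn
            self.num_spec_tokens = num_spec_tokens
            self.speculated_text = ""

        def gen(self, max_new_tokens, stop=None):
            if not self.speculated_text:
                # Over-generate once so later calls can reuse the leftover text.
                self.speculated_text = self.generate_fn(
                    max(max_new_tokens, self.num_spec_tokens)
                )
            # Cut at the stop string if found, else at the length budget.
            cut = self.speculated_text.find(stop) if stop is not None else -1
            if cut == -1:
                cut = min(max_new_tokens, len(self.speculated_text))
            comp = self.speculated_text[:cut]
            self.speculated_text = self.speculated_text[cut:]
            return comp

    cache = SpecCache(lambda n: "2, 3, 5, 7, 11, 13", num_spec_tokens=64)
    first = cache.gen(max_new_tokens=32, stop=",")  # "2"; ", 3, 5, ..." stays cached

As in `_spec_gen`, the stop string itself is left at the head of the cache; the surrounding interpreter is presumably expected to consume it when the next prompt chunk is filled.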
@@ -508,7 +517,7 @@ class StreamExecutor:
                 self.variable_event[name].set()
         else:
             assert (
-                self.backend.api_num_spec_tokens is None
+                self.api_num_spec_tokens is None
             ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
@@ -571,9 +580,10 @@ class StreamExecutor:
     def _execute_role_end(self, expr: SglRoleEnd):
         if (
             self.cur_role == "assistant"
-            and self.backend.api_num_spec_tokens is not None
             and self.backend.is_chat_model
+            and self.api_num_spec_tokens is not None
         ):
+            # Execute the stored lazy generation calls
             self.backend.role_end_generate(self)
         self.cur_role = None
 
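Review note: a quick way to exercise the new path end-to-end is a multi-call chat program, sketched below. The way `api_num_spec_tokens` is threaded into `run()` here is an assumption about this version's entry points (the diff only shows the `StreamExecutor` side); adjust to the actual signature if it differs.

    # Hypothetical usage sketch; the api_num_spec_tokens plumbing into run()
    # is an assumption, not a confirmed public API of this version.
    import sglang as sgl
    from sglang.backend.openai import OpenAI

    @sgl.function
    def two_answers(s):
        s += sgl.user("Name a prime number.")
        s += sgl.assistant(sgl.gen("a", max_tokens=16))
        s += sgl.user("Name another one.")
        s += sgl.assistant(sgl.gen("b", max_tokens=16))

    state = two_answers.run(
        backend=OpenAI("gpt-3.5-turbo"),
        api_num_spec_tokens=64,  # speculate both assistant turns in one call
    )
    print(state["a"], state["b"])

With a chat-only model this takes the `is_chat_model` branch above (the generation calls are stored and lazily executed at role end); with a completion model it goes through `_spec_gen`.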