Fix openai speculative execution (#456)

This commit is contained in:
Ying Sheng
2024-05-20 17:01:13 -07:00
committed by GitHub
parent ec380dfd30
commit 3e684be7a3
7 changed files with 243 additions and 128 deletions

View File

@@ -196,12 +196,6 @@ class StreamExecutor:
# For completion
self.text_ = "" # The full text
# For speculative execution
from sglang.backend.openai import OpenAI
if isinstance(backend, OpenAI):
self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
self.speculated_text = ""
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
# For fork/join
self.fork_start_text_pos = None
# For speculative execution
self.api_num_spec_tokens = api_num_spec_tokens
self.speculated_text = ""
# Worker thread
self.use_thread = use_thread
if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_)
# TODO(ying): handle API speculative execution
return exes
def text(self):
@@ -399,7 +399,7 @@ class StreamExecutor:
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.api_num_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
@@ -435,71 +435,80 @@ class StreamExecutor:
# if global_config.eager_fill_image:
# self.backend.fill_image(self)
def _spec_gen(self, sampling_params):
    """Serve one completion from the speculated-text buffer.

    The buffer ``self.speculated_text`` is refilled from
    ``self.backend.generate`` when it cannot serve the request, then the
    leading part of the buffer is split off and returned.

    Args:
        sampling_params: Object exposing ``stop`` (``None``, a string, or a
            list/tuple of strings) and ``max_new_tokens``. NOTE: when a
            refill happens this object is modified in place -- ``stop`` is
            set to ``None`` and ``max_new_tokens`` is raised to at least
            ``self.api_num_spec_tokens`` -- before it is handed to the
            backend.

    Returns:
        A ``(comp, meta_info)`` tuple. ``comp`` is the text taken from the
        front of the buffer; ``meta_info`` is whatever the backend returned
        for a refill performed during this call, or ``{}`` if the buffer
        was served without calling the backend.

    Raises:
        ValueError: If ``stop`` is neither ``None``, a string, a list, nor
            a tuple.
    """
    stop = sampling_params.stop
    # Capture the caller's limit now: a refill widens
    # ``sampling_params.max_new_tokens``, and the widened value must not
    # decide how much text this particular request receives.
    max_new_tokens = sampling_params.max_new_tokens

    # Validate once, up front, before any backend call is made.
    if stop is not None and not isinstance(stop, (str, list, tuple)):
        raise ValueError("Wrong type of stop in sampling parameters.")

    meta_info = {}

    def regen():
        """Replace the buffer with a fresh, stop-free backend generation."""
        nonlocal meta_info
        sampling_params.max_new_tokens = max(
            sampling_params.max_new_tokens, self.api_num_spec_tokens
        )
        sampling_params.stop = None
        self.speculated_text, meta_info = self.backend.generate(
            self, sampling_params=sampling_params
        )

    if stop is None:
        # The buffer length (in characters) is compared directly against
        # max_new_tokens, and the same number is used as the slice length.
        if len(self.speculated_text) < max_new_tokens:
            regen()
        cut = max_new_tokens
    else:
        if self.speculated_text == "":
            regen()
        stop_strs = (stop,) if isinstance(stop, str) else stop
        hits = [
            pos
            for pos in (self.speculated_text.find(s) for s in stop_strs)
            if pos != -1
        ]
        if hits:
            # Cut at the earliest stop string; the stop string itself is
            # left at the front of the remaining buffer.
            cut = min(hits)
        else:
            # No stop string present: fall back to the request's own limit,
            # regardless of whether a refill happened above.
            cut = min(max_new_tokens, len(self.speculated_text))

    comp = self.speculated_text[:cut]
    self.speculated_text = self.speculated_text[cut:]
    return comp, meta_info
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
if self.backend.api_num_spec_tokens is None:
if self.api_num_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
elif self.backend.is_chat_model:
# spec on model with only chat interface
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
name=name,
)
return
else: # spec on model with completion
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
self, sampling_params=sampling_params
else:
if self.backend.is_chat_model:
# Speculative execution on models with only chat interface.
# Store the calls into a temporary list.
# They will be lazily executed later.
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
spec_var_name=name,
)
return
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
return pos
else:
raise Exception("Wrong type of stop in sampling parameters.")
if stop is None:
if len(self.speculated_text) < max_new_tokens:
regen()
comp = self.speculated_text[:max_new_tokens]
self.speculated_text = self.speculated_text[max_new_tokens:]
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos = find_stop()
if stop_pos == -1:
stop_pos = min(
sampling_params.max_new_tokens,
len(self.speculated_text),
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
else: # Speculative execution on models with completion interface
comp, meta_info = self._spec_gen(sampling_params)
self.text_ += comp
@@ -508,7 +517,7 @@ class StreamExecutor:
self.variable_event[name].set()
else:
assert (
self.backend.api_num_spec_tokens is None
self.api_num_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
@@ -571,9 +580,10 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
and self.api_num_spec_tokens is not None
):
# Execute the stored lazy generation calls
self.backend.role_end_generate(self)
self.cur_role = None