Support speculative execution for the OpenAI API (#48)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
parasol-aser
2024-01-25 03:57:06 -06:00
committed by GitHub
parent 93414c8238
commit 23950056f0
10 changed files with 178 additions and 12 deletions

View File

@@ -51,10 +51,14 @@ def run_program(
if hasattr(backend, "endpoint"):
backend = backend.endpoint
assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor(
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
backend,
func_kwargs,
default_sampling_para,
chat_template=None,
stream=stream,
api_num_spec_tokens=program.api_num_spec_tokens,
)
state = ProgramState(stream_executor)
@@ -175,6 +179,7 @@ class StreamExecutor:
default_sampling_para,
chat_template,
stream,
api_num_spec_tokens=None,
use_thread=True,
):
self.sid = uuid.uuid4().hex
@@ -182,6 +187,7 @@ class StreamExecutor:
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
self.api_num_spec_tokens = api_num_spec_tokens
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
@@ -191,6 +197,9 @@ class StreamExecutor:
# For completion
self.text_ = "" # The full text
# For speculative execution
self.speculated_text = ""
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
@@ -341,6 +350,10 @@ class StreamExecutor:
def _execute_fill(self, value: str):
value = str(value)
if self.speculated_text.startswith(value):
self.speculated_text = self.speculated_text[len(value) :]
else:
self.speculated_text = ""
self.text_ += value
def _execute_image(self, expr: SglImage):
@@ -360,9 +373,61 @@ class StreamExecutor:
name = expr.name
if not self.stream:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
if self.api_num_spec_tokens is not None:
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop), len(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
stop_len = 0
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
stop_len = len(stop_str)
return pos, stop_len
else:
raise Exception("Wrong type of stop in sampling parameters.")
if stop is None:
if len(self.speculated_text) < max_new_tokens:
regen()
comp = self.speculated_text[:max_new_tokens]
self.speculated_text = self.speculated_text[max_new_tokens:]
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos, stop_len = find_stop()
if stop_pos == -1:
stop_pos, stop_len = (
min(
sampling_params.max_new_tokens,
len(self.speculated_text),
),
0,
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
else:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
self.text_ += comp
self.variables[name] = comp

View File

@@ -95,8 +95,9 @@ class SglSamplingParams:
class SglFunction:
def __init__(self, func, bind_arguments=None):
def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
self.func = func
self.api_num_spec_tokens = api_num_spec_tokens
self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None