openai chat speculative execution (#250)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
LiviaSun
2024-05-18 22:23:53 -07:00
committed by GitHub
parent 5b647543c1
commit ec380dfd30
11 changed files with 316 additions and 45 deletions

View File

@@ -6,6 +6,7 @@ import multiprocessing
import queue
import threading
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
@@ -185,7 +186,6 @@ class StreamExecutor:
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
self.api_num_spec_tokens = api_num_spec_tokens
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
@@ -197,6 +197,9 @@ class StreamExecutor:
self.text_ = "" # The full text
# For speculative execution
from sglang.backend.openai import OpenAI
if isinstance(backend, OpenAI):
self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
self.speculated_text = ""
# For chat
@@ -322,7 +325,7 @@ class StreamExecutor:
try:
self._execute(expr)
except Exception as e:
# print(f"Error in stream_executor: {get_exception_traceback()}")
warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
error = e
break
self.queue.task_done()
@@ -391,12 +394,23 @@ class StreamExecutor:
else:
raise ValueError(f"Unknown type: {type(other)}")
def _execute_fill(self, value: str):
def _execute_fill(self, value: str, prefix=False):
value = str(value)
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
self.backend.spec_fill(value)
return
if self.speculated_text.startswith(value):
self.speculated_text = self.speculated_text[len(value) :]
else:
self.speculated_text = ""
self.text_ += value
def _execute_image(self, expr: SglImage):
@@ -426,14 +440,29 @@ class StreamExecutor:
name = expr.name
if not self.stream:
if self.api_num_spec_tokens is not None:
if self.backend.api_num_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
elif self.backend.is_chat_model:
# spec on model with only chat interface
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
name=name,
)
return
else: # spec on model with completion
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
@@ -442,16 +471,14 @@ class StreamExecutor:
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop), len(stop)
return self.speculated_text.find(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
stop_len = 0
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
stop_len = len(stop_str)
return pos, stop_len
return pos
else:
raise Exception("Wrong type of stop in sampling parameters.")
@@ -463,23 +490,16 @@ class StreamExecutor:
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos, stop_len = find_stop()
stop_pos = find_stop()
if stop_pos == -1:
stop_pos, stop_len = (
min(
sampling_params.max_new_tokens,
len(self.speculated_text),
),
0,
stop_pos = min(
sampling_params.max_new_tokens,
len(self.speculated_text),
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
else:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
self.text_ += comp
@@ -487,6 +507,9 @@ class StreamExecutor:
self.meta_info[name] = meta_info
self.variable_event[name].set()
else:
assert (
self.backend.api_num_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
)
@@ -542,10 +565,18 @@ class StreamExecutor:
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(prefix)
self._execute_fill(prefix, prefix=True)
self.cur_role_begin_pos = len(self.text_)
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
):
self.backend.role_end_generate(self)
self.cur_role = None
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -572,8 +603,6 @@ class StreamExecutor:
# OpenAI chat API format
self.messages_.append({"role": expr.role, "content": new_text})
self.cur_role = None
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
self.variables[expr.name] = int(len(self.text_))