openai chat speculative execution (#250)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
@@ -185,7 +186,6 @@ class StreamExecutor:
|
||||
self.arguments: Dict[str, Any] = arguments
|
||||
self.default_sampling_para = default_sampling_para
|
||||
self.stream = stream
|
||||
self.api_num_spec_tokens = api_num_spec_tokens
|
||||
|
||||
self.variables = {} # Dict[name: str -> value: str]
|
||||
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
||||
@@ -197,6 +197,9 @@ class StreamExecutor:
|
||||
self.text_ = "" # The full text
|
||||
|
||||
# For speculative execution
|
||||
from sglang.backend.openai import OpenAI
|
||||
if isinstance(backend, OpenAI):
|
||||
self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
|
||||
self.speculated_text = ""
|
||||
|
||||
# For chat
|
||||
@@ -322,7 +325,7 @@ class StreamExecutor:
|
||||
try:
|
||||
self._execute(expr)
|
||||
except Exception as e:
|
||||
# print(f"Error in stream_executor: {get_exception_traceback()}")
|
||||
warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
|
||||
error = e
|
||||
break
|
||||
self.queue.task_done()
|
||||
@@ -391,12 +394,23 @@ class StreamExecutor:
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type(other)}")
|
||||
|
||||
def _execute_fill(self, value: str):
|
||||
def _execute_fill(self, value: str, prefix: bool = False) -> None:
    """Append literal text to the running program text, or defer it to the backend.

    Args:
        value: The text to fill. It is coerced with ``str()`` first, so
            non-string values are accepted.
        prefix: When True, the text is always appended locally and the
            speculative branch below is skipped. The call site that fills
            the chat-template role prefix passes ``prefix=True``.

    Behavior:
        * If the current role is ``"assistant"``, the backend has
          ``api_num_spec_tokens`` set (not None), the backend reports
          ``is_chat_model``, and ``prefix`` is False, the text is handed to
          ``self.backend.spec_fill`` and this method returns immediately.
          Neither ``self.text_`` nor ``self.speculated_text`` is modified
          on that path.
        * Otherwise the text is appended to ``self.text_`` and
          ``self.speculated_text`` is reconciled with it (see below).
    """
    value = str(value)

    if (
        self.cur_role == "assistant"
        and self.backend.api_num_spec_tokens is not None
        and self.backend.is_chat_model
        and not prefix
    ):
        # Speculative execution on a chat model: the backend collects the
        # fill itself, so nothing is appended to the local text here.
        self.backend.spec_fill(value)
        return

    if self.speculated_text.startswith(value):
        # The fill matches the head of the speculated text: consume that
        # part and keep the remainder.
        self.speculated_text = self.speculated_text[len(value) :]
    else:
        # The fill does not match the speculated text: discard it.
        self.speculated_text = ""

    self.text_ += value
|
||||
|
||||
def _execute_image(self, expr: SglImage):
|
||||
@@ -426,14 +440,29 @@ class StreamExecutor:
|
||||
name = expr.name
|
||||
|
||||
if not self.stream:
|
||||
if self.api_num_spec_tokens is not None:
|
||||
if self.backend.api_num_spec_tokens is None:
|
||||
comp, meta_info = self.backend.generate(
|
||||
self,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
elif self.backend.is_chat_model:
|
||||
# spec on model with only chat interface
|
||||
comp, meta_info = self.backend.generate(
|
||||
self,
|
||||
sampling_params=sampling_params,
|
||||
name=name,
|
||||
)
|
||||
return
|
||||
|
||||
else: # spec on model with completion
|
||||
stop = sampling_params.stop
|
||||
max_new_tokens = sampling_params.max_new_tokens
|
||||
meta_info = {}
|
||||
|
||||
def regen():
|
||||
sampling_params.max_new_tokens = max(
|
||||
sampling_params.max_new_tokens, self.api_num_spec_tokens
|
||||
sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
|
||||
)
|
||||
sampling_params.stop = None
|
||||
self.speculated_text, meta_info = self.backend.generate(
|
||||
@@ -442,16 +471,14 @@ class StreamExecutor:
|
||||
|
||||
def find_stop():
|
||||
if isinstance(stop, str):
|
||||
return self.speculated_text.find(stop), len(stop)
|
||||
return self.speculated_text.find(stop)
|
||||
elif isinstance(stop, (tuple, list)):
|
||||
pos = -1
|
||||
stop_len = 0
|
||||
for stop_str in stop:
|
||||
stop_pos = self.speculated_text.find(stop_str)
|
||||
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
|
||||
pos = stop_pos
|
||||
stop_len = len(stop_str)
|
||||
return pos, stop_len
|
||||
return pos
|
||||
else:
|
||||
raise Exception("Wrong type of stop in sampling parameters.")
|
||||
|
||||
@@ -463,23 +490,16 @@ class StreamExecutor:
|
||||
elif isinstance(stop, (str, list, tuple)):
|
||||
if self.speculated_text == "":
|
||||
regen()
|
||||
stop_pos, stop_len = find_stop()
|
||||
stop_pos = find_stop()
|
||||
if stop_pos == -1:
|
||||
stop_pos, stop_len = (
|
||||
min(
|
||||
sampling_params.max_new_tokens,
|
||||
len(self.speculated_text),
|
||||
),
|
||||
0,
|
||||
stop_pos = min(
|
||||
sampling_params.max_new_tokens,
|
||||
len(self.speculated_text),
|
||||
)
|
||||
comp = self.speculated_text[:stop_pos]
|
||||
self.speculated_text = self.speculated_text[stop_pos:]
|
||||
else:
|
||||
raise ValueError("Wrong type of stop in sampling parameters.")
|
||||
else:
|
||||
comp, meta_info = self.backend.generate(
|
||||
self, sampling_params=sampling_params
|
||||
)
|
||||
|
||||
self.text_ += comp
|
||||
|
||||
@@ -487,6 +507,9 @@ class StreamExecutor:
|
||||
self.meta_info[name] = meta_info
|
||||
self.variable_event[name].set()
|
||||
else:
|
||||
assert (
|
||||
self.backend.api_num_spec_tokens is None
|
||||
), "stream is not supported with api speculative execution"
|
||||
generator = self.backend.generate_stream(
|
||||
self, sampling_params=sampling_params
|
||||
)
|
||||
@@ -542,10 +565,18 @@ class StreamExecutor:
|
||||
|
||||
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
|
||||
|
||||
self._execute_fill(prefix)
|
||||
self._execute_fill(prefix, prefix=True)
|
||||
self.cur_role_begin_pos = len(self.text_)
|
||||
|
||||
def _execute_role_end(self, expr: SglRoleEnd):
|
||||
if (
|
||||
self.cur_role == "assistant"
|
||||
and self.backend.api_num_spec_tokens is not None
|
||||
and self.backend.is_chat_model
|
||||
):
|
||||
self.backend.role_end_generate(self)
|
||||
self.cur_role = None
|
||||
|
||||
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
|
||||
|
||||
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
|
||||
@@ -572,8 +603,6 @@ class StreamExecutor:
|
||||
# OpenAI chat API format
|
||||
self.messages_.append({"role": expr.role, "content": new_text})
|
||||
|
||||
self.cur_role = None
|
||||
|
||||
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
|
||||
self.variables[expr.name] = int(len(self.text_))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user