Fix openai speculative execution (#456)
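Summary: this change moves the speculative-execution state (api_num_spec_tokens, speculated_text) from the OpenAI backend onto the per-program StreamExecutor, adds a TokenUsage counter that openai_completion / openai_completion_stream update, factors the chat-path bookkeeping into _prepare_spec_execution and _spec_gen, and adds tests for both the completion and chat paths.

Usage sketch (not part of the diff; adapted from the new tests at the bottom, backend setup and model name are assumptions):

import sglang as sgl

sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))  # assumed model name

# api_num_spec_tokens asks the backend to speculate one larger completion
# and serve the following sgl.gen calls from it.
@sgl.function(api_num_spec_tokens=64)
def gen_character_spec(s):
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
    s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

state = gen_character_spec.run()
print(state["name"], state["birthday"], state["job"])

# Token accounting added in this commit:
usage = sgl.global_config.default_backend.token_usage
print(usage.prompt_tokens, usage.completion_tokens)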
@@ -9,7 +9,6 @@ class BaseBackend:
    def __init__(self) -> None:
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")
        self.api_num_spec_tokens = None

    def get_model_name(self):
        raise NotImplementedError()
@@ -1,6 +1,7 @@
import logging
import time
import warnings
import dataclasses
from typing import Callable, List, Optional, Union

import numpy as np
@@ -42,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
]


@dataclasses.dataclass
class TokenUsage:
    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        self.prompt_tokens = self.completion_tokens = 0


class OpenAI(BaseBackend):
    def __init__(
        self,
@@ -83,66 +93,73 @@ class OpenAI(BaseBackend):
        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Usage
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3
        self.api_num_spec_tokens = None

    def set_api_num_spec_tokens(self, num):
        self.api_num_spec_tokens = num

    def get_chat_template(self):
        return self.chat_template

    def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
                                api_num_spec_tokens: int, spec_var_name: str):
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = api_num_spec_tokens
        else:
            assert (
                self.spec_kwargs["max_tokens"] == api_num_spec_tokens
            )

        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
            if key in ["stop"]:
                continue
            if key in ["max_tokens"]:
                warnings.warn(
                    "The parameter max_tokens will be overwritten by speculated number of tokens."
                )
                continue
            if key not in self.spec_kwargs:
                self.spec_kwargs[key] = value
            else:
                assert (
                    value == self.spec_kwargs[key]
                ), "sampling parameters should be consistent if turn on api speculative execution."
        self.spec_format.append(
            {"text": "", "stop": params["stop"], "name": spec_var_name}
        )
        return "", {}

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        name=None,
        spec_var_name: str = None,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if self.api_num_spec_tokens is None:
                if s.api_num_spec_tokens is None:
                    if not s.text_.endswith(self.chat_prefix):
                        raise RuntimeError(
                            "This use case is not supported if api speculative execution is off. "
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant."
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                            "Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
                        )
                    prompt = s.messages_
                else:
                    # collect assistant answer format
                    if "max_tokens" not in self.spec_kwargs:
                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
                    else:
                        assert (
                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
                        )
                    params = sampling_params.to_openai_kwargs()
                    for key, value in params.items():
                        if key in ["stop"]:
                            continue
                        if key in ["max_tokens"]:
                            warnings.warn(
                                "The parameter max_tokens will be overwritten by speculated number of tokens."
                            )
                            continue
                        if key not in self.spec_kwargs:
                            self.spec_kwargs[key] = value
                        else:
                            assert (
                                value == self.spec_kwargs[key]
                            ), "sampling parameters should be consistent if turn on api speculative execution."
                    self.spec_format.append(
                        {"text": "", "stop": params["stop"], "name": name}
                    )
                    return "", {}
                    return self._prepare_spec_execution(sampling_params,
                                                        s.api_num_spec_tokens, spec_var_name)
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
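Note on the hunk above (not part of the diff): when speculative execution is on and the model only has a chat interface, _prepare_spec_execution does not call the API. It records the shared sampling parameters in self.spec_kwargs, appends one placeholder entry per sgl.gen call to self.spec_format, and returns an empty string so the program keeps running; role_end_generate later issues the real chat completion for the whole assistant turn. A rough illustration of the accumulated state (values are hypothetical):

# After two speculated gen calls named "name" and "birthday":
spec_kwargs = {"max_tokens": 128, "temperature": 0.0}  # shared sampling params
spec_format = [
    {"text": "", "stop": "\n", "name": "name"},
    {"text": "", "stop": "\n", "name": "birthday"},
]
print([entry["name"] for entry in spec_format])  # ['name', 'birthday']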
@@ -156,6 +173,7 @@ class OpenAI(BaseBackend):
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',

@@ -171,6 +189,7 @@ class OpenAI(BaseBackend):
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
@@ -211,14 +230,16 @@ class OpenAI(BaseBackend):
        self,
        s: StreamExecutor,
    ):
        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
        if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
            return

        comp = ""
        if not all(x["name"] is None for x in self.spec_format):
            # TODO(ying): throw errors or warnings
            for i in range(self.spec_max_num_tries):
                comp = openai_completion(
                    client=self.client,
                    token_usage=self.token_usage,
                    is_chat=self.is_chat_model,
                    model=self.model_name,
                    prompt=s.messages_,

@@ -228,7 +249,6 @@ class OpenAI(BaseBackend):
                    break

        for term in self.spec_format:
            stop = term["stop"] if term["stop"] is not None else ""
            s.text_ += term["text"]
            name = term["name"]
            if name is not None:
@@ -258,6 +278,7 @@ class OpenAI(BaseBackend):
        kwargs = sampling_params.to_openai_kwargs()
        generator = openai_completion_stream(
            client=self.client,
            token_usage=self.token_usage,
            is_chat=self.is_chat_model,
            model=self.model_name,
            prompt=prompt,
@@ -303,6 +324,8 @@ class OpenAI(BaseBackend):
        )
        ret_str = ret.choices[0].text
        ret_token = self.tokenizer.encode(ret_str)[0]
        self.token_usage.prompt_tokens += ret.usage.prompt_tokens
        self.token_usage.completion_tokens += ret.usage.completion_tokens

        # TODO:
        # 1. return logits as the scores
@@ -332,7 +355,7 @@ class OpenAI(BaseBackend):
        return decision, scores, None, None


def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
    for attempt in range(retries):
        try:
            if is_chat:

@@ -346,6 +369,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                    comp = [c.text for c in ret.choices]
                else:
                    comp = ret.choices[0].text

            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -359,16 +385,19 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
    return comp


def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
    for attempt in range(retries):
        try:
            if is_chat:
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt, stream=True, **kwargs
                    messages=prompt, stream=True, stream_options={"include_usage": True},
                    **kwargs
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    try:
                        content = ret.choices[0].delta.content
                    except IndexError:

@@ -376,11 +405,17 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt, stream=True, **kwargs
                    prompt=prompt, stream=True, stream_options={"include_usage": True},
                    **kwargs
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    content = ret.choices[0].text
                    yield content or "", {}

            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
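Note on the streaming hunks above (not part of the diff): with stream_options={"include_usage": True}, the OpenAI API appends a final chunk whose choices list is empty and whose usage field carries the token counts, which is why the loops now skip empty-choice chunks and the counters are read from ret.usage after the loop. A minimal standalone sketch of that pattern (client construction and model name are assumptions, not taken from the diff):

import openai

client = openai.OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",  # assumed model
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

usage = None
for chunk in stream:
    if len(chunk.choices) == 0:
        # The final chunk has no choices; it only carries usage statistics.
        usage = chunk.usage
        continue
    print(chunk.choices[0].delta.content or "", end="")

if usage is not None:
    print("\nprompt:", usage.prompt_tokens, "completion:", usage.completion_tokens)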
@@ -196,12 +196,6 @@ class StreamExecutor:
        # For completion
        self.text_ = ""  # The full text

        # For speculative execution
        from sglang.backend.openai import OpenAI
        if isinstance(backend, OpenAI):
            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
        self.speculated_text = ""

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
        self.chat_template = chat_template or self.backend.get_chat_template()

@@ -215,6 +209,10 @@ class StreamExecutor:
        # For fork/join
        self.fork_start_text_pos = None

        # For speculative execution
        self.api_num_spec_tokens = api_num_spec_tokens
        self.speculated_text = ""

        # Worker thread
        self.use_thread = use_thread
        if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
            exes[i].fork_start_text_pos = len(self.text_)
            exes[i].images_ = list(self.images_)

            # TODO(ying): handle API speculative execution

        return exes

    def text(self):
@@ -399,7 +399,7 @@ class StreamExecutor:

        if (
            self.cur_role == "assistant"
            and self.backend.api_num_spec_tokens is not None
            and self.api_num_spec_tokens is not None
            and self.backend.is_chat_model
            and not prefix
        ):
@@ -435,71 +435,80 @@ class StreamExecutor:
        # if global_config.eager_fill_image:
        #     self.backend.fill_image(self)

    def _spec_gen(self, sampling_params):
        stop = sampling_params.stop
        max_new_tokens = sampling_params.max_new_tokens
        meta_info = {}

        def regen():
            nonlocal meta_info

            sampling_params.max_new_tokens = max(
                sampling_params.max_new_tokens, self.api_num_spec_tokens
            )
            sampling_params.stop = None
            self.speculated_text, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )

        def find_stop():
            if isinstance(stop, str):
                return self.speculated_text.find(stop)
            elif isinstance(stop, (tuple, list)):
                pos = -1
                for stop_str in stop:
                    stop_pos = self.speculated_text.find(stop_str)
                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                        pos = stop_pos
                return pos
            else:
                raise Exception("Wrong type of stop in sampling parameters.")

        if stop is None:
            if len(self.speculated_text) < max_new_tokens:
                regen()
            comp = self.speculated_text[:max_new_tokens]
            self.speculated_text = self.speculated_text[max_new_tokens:]
        elif isinstance(stop, (str, list, tuple)):
            if self.speculated_text == "":
                regen()
            stop_pos = find_stop()
            if stop_pos == -1:
                stop_pos = min(
                    sampling_params.max_new_tokens,
                    len(self.speculated_text),
                )
            comp = self.speculated_text[:stop_pos]
            self.speculated_text = self.speculated_text[stop_pos:]
        else:
            raise ValueError("Wrong type of stop in sampling parameters.")

        return comp, meta_info

    def _execute_gen(self, expr: SglGen):
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name

        if not self.stream:
            if self.backend.api_num_spec_tokens is None:
            if self.api_num_spec_tokens is None:
                comp, meta_info = self.backend.generate(
                    self,
                    sampling_params=sampling_params,
                )

            elif self.backend.is_chat_model:
                # spec on model with only chat interface
                comp, meta_info = self.backend.generate(
                    self,
                    sampling_params=sampling_params,
                    name=name,
                )
                return

            else:  # spec on model with completion
                stop = sampling_params.stop
                max_new_tokens = sampling_params.max_new_tokens
                meta_info = {}

                def regen():
                    sampling_params.max_new_tokens = max(
                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
                    )
                    sampling_params.stop = None
                    self.speculated_text, meta_info = self.backend.generate(
                        self, sampling_params=sampling_params
            else:
                if self.backend.is_chat_model:
                    # Speculative execution on models with only chat interface.
                    # Store the calls into a temporary list.
                    # They will be lazily executed later.
                    comp, meta_info = self.backend.generate(
                        self,
                        sampling_params=sampling_params,
                        spec_var_name=name,
                    )
                    return

                def find_stop():
                    if isinstance(stop, str):
                        return self.speculated_text.find(stop)
                    elif isinstance(stop, (tuple, list)):
                        pos = -1
                        for stop_str in stop:
                            stop_pos = self.speculated_text.find(stop_str)
                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                                pos = stop_pos
                        return pos
                    else:
                        raise Exception("Wrong type of stop in sampling parameters.")

                if stop is None:
                    if len(self.speculated_text) < max_new_tokens:
                        regen()
                    comp = self.speculated_text[:max_new_tokens]
                    self.speculated_text = self.speculated_text[max_new_tokens:]
                elif isinstance(stop, (str, list, tuple)):
                    if self.speculated_text == "":
                        regen()
                    stop_pos = find_stop()
                    if stop_pos == -1:
                        stop_pos = min(
                            sampling_params.max_new_tokens,
                            len(self.speculated_text),
                        )
                    comp = self.speculated_text[:stop_pos]
                    self.speculated_text = self.speculated_text[stop_pos:]
                else:
                    raise ValueError("Wrong type of stop in sampling parameters.")
                else:  # Speculative execution on models with completion interface
                    comp, meta_info = self._spec_gen(sampling_params)

            self.text_ += comp
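Note on _spec_gen above (not part of the diff): for completion-style models, the first sgl.gen issues one request of at least api_num_spec_tokens tokens with stop strings disabled and caches the result in self.speculated_text; each later sgl.gen slices its answer off that cache up to the earliest stop string (or up to max_new_tokens when no stop is set), regenerating only when the cache is empty or too short. A tiny illustration of the earliest-stop search mirrored by find_stop (the text and stops are made up):

speculated_text = "Job: Apple CEO. Founded NeXT.\n"
stop = ("\n", ".")
pos = -1
for stop_str in stop:
    stop_pos = speculated_text.find(stop_str)
    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
        pos = stop_pos
print(speculated_text[:pos])  # "Job: Apple CEO" -- cut at the earliest stop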
@@ -508,7 +517,7 @@ class StreamExecutor:
            self.variable_event[name].set()
        else:
            assert (
                self.backend.api_num_spec_tokens is None
                self.api_num_spec_tokens is None
            ), "stream is not supported with api speculative execution"
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
@@ -571,9 +580,10 @@ class StreamExecutor:
    def _execute_role_end(self, expr: SglRoleEnd):
        if (
            self.cur_role == "assistant"
            and self.backend.api_num_spec_tokens is not None
            and self.backend.is_chat_model
            and self.api_num_spec_tokens is not None
        ):
            # Execute the stored lazy generation calls
            self.backend.role_end_generate(self)
        self.cur_role = None
@@ -304,6 +304,7 @@ def test_image_qa():
        temperature=0,
        max_new_tokens=64,
    )

    assert (
        "taxi" in state.messages()[-1]["content"]
        or "car" in state.messages()[-1]["content"]

@@ -349,3 +350,46 @@ def test_regex():
    state = regex_gen.run()
    answer = state["answer"]
    assert re.match(regex, answer)


def test_completion_speculative():
    @sgl.function(api_num_spec_tokens=64)
    def gen_character_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    @sgl.function
    def gen_character_no_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    token_usage = sgl.global_config.default_backend.token_usage

    token_usage.reset()
    gen_character_spec().sync()
    usage_with_spec = token_usage.prompt_tokens

    token_usage.reset()
    gen_character_no_spec().sync()
    usage_with_no_spec = token_usage.prompt_tokens

    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"


def test_chat_completion_speculative():
    @sgl.function(api_num_spec_tokens=256)
    def gen_character_spec(s):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))

    gen_character_spec().sync()