Fix openai speculative execution (#456)

Ying Sheng
2024-05-20 17:01:13 -07:00
committed by GitHub
parent ec380dfd30
commit 3e684be7a3
7 changed files with 243 additions and 128 deletions


@@ -9,7 +9,6 @@ class BaseBackend:
def __init__(self) -> None:
self.support_concate_and_append = False
self.chat_template = get_chat_template("default")
self.api_num_spec_tokens = None
def get_model_name(self):
raise NotImplementedError()


@@ -1,6 +1,7 @@
import logging
import time
import warnings
import dataclasses
from typing import Callable, List, Optional, Union
import numpy as np
@@ -42,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
]
@dataclasses.dataclass
class TokenUsage:
prompt_tokens: int
completion_tokens: int
def reset(self):
self.prompt_tokens = self.completion_tokens = 0
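
Aside: TokenUsage is a plain accumulator that the completion helpers below feed after every API call; a minimal self-contained sketch of its contract (the numbers are made up):

    import dataclasses

    @dataclasses.dataclass
    class TokenUsage:
        prompt_tokens: int
        completion_tokens: int

        def reset(self):
            self.prompt_tokens = self.completion_tokens = 0

    usage = TokenUsage(0, 0)
    usage.prompt_tokens += 120     # e.g. ret.usage.prompt_tokens from one call
    usage.completion_tokens += 48  # e.g. ret.usage.completion_tokens
    assert (usage.prompt_tokens, usage.completion_tokens) == (120, 48)
    usage.reset()                  # the tests below call this between runs
    assert usage.prompt_tokens == usage.completion_tokens == 0
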
class OpenAI(BaseBackend):
def __init__(
self,
@@ -83,66 +93,73 @@ class OpenAI(BaseBackend):
self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
# Usage
self.token_usage = TokenUsage(0, 0)
# API speculative execution
# TODO(ying): This does not support multi-threading (run_batch)
self.spec_kwargs = {}
self.spec_format = []
self.spec_max_num_tries = 3
self.api_num_spec_tokens = None
def set_api_num_spec_tokens(self, num):
self.api_num_spec_tokens = num
def get_chat_template(self):
return self.chat_template
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
api_num_spec_tokens: int, spec_var_name: str):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = api_num_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == api_num_spec_tokens
)
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
self.spec_kwargs[key] = value
else:
assert (
value == self.spec_kwargs[key]
), "sampling parameters should be consistent if turn on api speculative execution."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": spec_var_name}
)
return "", {}
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
name=None,
spec_var_name: str = None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if self.api_num_spec_tokens is None:
if s.api_num_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant."
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
)
prompt = s.messages_
else:
# collect assistant answer format
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
)
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
self.spec_kwargs[key] = value
else:
assert (
value == self.spec_kwargs[key]
), "sampling parameters should be consistent if turn on api speculative execution."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": name}
)
return "", {}
return self._prepare_spec_execution(sampling_params,
s.api_num_spec_tokens, spec_var_name)
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
@@ -156,6 +173,7 @@ class OpenAI(BaseBackend):
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
@@ -171,6 +189,7 @@ class OpenAI(BaseBackend):
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
@@ -211,14 +230,16 @@ class OpenAI(BaseBackend):
self,
s: StreamExecutor,
):
if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
if not all(x["name"] is None for x in self.spec_format):
# TODO(ying): throw errors or warnings
for i in range(self.spec_max_num_tries):
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.messages_,
@@ -228,7 +249,6 @@ class OpenAI(BaseBackend):
break
for term in self.spec_format:
stop = term["stop"] if term["stop"] is not None else ""
s.text_ += term["text"]
name = term["name"]
if name is not None:
@@ -258,6 +278,7 @@ class OpenAI(BaseBackend):
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
@@ -303,6 +324,8 @@ class OpenAI(BaseBackend):
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens += ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
@@ -332,7 +355,7 @@ class OpenAI(BaseBackend):
return decision, scores, None, None
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
for attempt in range(retries):
try:
if is_chat:
@@ -346,6 +369,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
comp = [c.text for c in ret.choices]
else:
comp = ret.choices[0].text
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -359,16 +385,19 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
return comp
def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt, stream=True, **kwargs
messages=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
)
for ret in generator:
if len(ret.choices) == 0:
continue
try:
content = ret.choices[0].delta.content
except IndexError:
@@ -376,11 +405,17 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt, stream=True, **kwargs
prompt=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
)
for ret in generator:
if len(ret.choices) == 0:
continue
content = ret.choices[0].text
yield content or "", {}
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")


@@ -196,12 +196,6 @@ class StreamExecutor:
# For completion
self.text_ = "" # The full text
# For speculative execution
from sglang.backend.openai import OpenAI
if isinstance(backend, OpenAI):
self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
self.speculated_text = ""
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
# For fork/join
self.fork_start_text_pos = None
# For speculative execution
self.api_num_spec_tokens = api_num_spec_tokens
self.speculated_text = ""
# Worker thread
self.use_thread = use_thread
if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_)
# TODO(ying): handle API speculative execution
return exes
def text(self):
@@ -399,7 +399,7 @@ class StreamExecutor:
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.api_num_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
@@ -435,71 +435,80 @@ class StreamExecutor:
# if global_config.eager_fill_image:
# self.backend.fill_image(self)
def _spec_gen(self, sampling_params):
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
nonlocal meta_info
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
return pos
else:
raise Exception("Wrong type of stop in sampling parameters.")
if stop is None:
if len(self.speculated_text) < max_new_tokens:
regen()
comp = self.speculated_text[:max_new_tokens]
self.speculated_text = self.speculated_text[max_new_tokens:]
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos = find_stop()
if stop_pos == -1:
stop_pos = min(
sampling_params.max_new_tokens,
len(self.speculated_text),
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
return comp, meta_info
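
The stop handling in _spec_gen reduces to: cut at the earliest occurrence of any stop string in the speculated text, else at max_new_tokens, and keep the remainder for the next gen. A self-contained sketch of that rule (split_at_stop is illustrative, not part of the codebase):

    def split_at_stop(speculated, stop, max_new_tokens):
        """Return (consumed, remainder) for one gen call."""
        if stop is None:
            return speculated[:max_new_tokens], speculated[max_new_tokens:]
        stops = [stop] if isinstance(stop, str) else list(stop)
        pos = -1
        for s in stops:
            p = speculated.find(s)
            if p != -1 and (pos == -1 or p < pos):
                pos = p  # earliest stop string wins
        if pos == -1:
            pos = min(max_new_tokens, len(speculated))
        return speculated[:pos], speculated[pos:]

    text = "Alan Turing\nBirthday: June 23, 1912\n"
    name, rest = split_at_stop(text, "\n", 64)
    assert name == "Alan Turing" and rest.startswith("\n")
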
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
if self.backend.api_num_spec_tokens is None:
if self.api_num_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
elif self.backend.is_chat_model:
# spec on model with only chat interface
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
name=name,
)
return
else: # spec on model with completion
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
self, sampling_params=sampling_params
else:
if self.backend.is_chat_model:
# Speculative execution on models with only a chat interface.
# Store the calls in a temporary list;
# they will be executed lazily later.
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
spec_var_name=name,
)
return
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
return pos
else:
raise Exception("Wrong type of stop in sampling parameters.")
if stop is None:
if len(self.speculated_text) < max_new_tokens:
regen()
comp = self.speculated_text[:max_new_tokens]
self.speculated_text = self.speculated_text[max_new_tokens:]
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos = find_stop()
if stop_pos == -1:
stop_pos = min(
sampling_params.max_new_tokens,
len(self.speculated_text),
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
else: # Speculative execution on models with a completion interface
comp, meta_info = self._spec_gen(sampling_params)
self.text_ += comp
@@ -508,7 +517,7 @@ class StreamExecutor:
self.variable_event[name].set()
else:
assert (
self.backend.api_num_spec_tokens is None
self.api_num_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
@@ -571,9 +580,10 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.backend.api_num_spec_tokens is not None
and self.backend.is_chat_model
and self.api_num_spec_tokens is not None
):
# Execute the stored lazy generation calls
self.backend.role_end_generate(self)
self.cur_role = None
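
The chat-model path is record-then-flush: each sgl.gen inside sgl.assistant only appends a {"text", "stop", "name"} entry to spec_format, and role_end_generate later issues a single completion and splits it back into named variables by the recorded stop strings. A simplified sketch of the splitting step (parse_spec is hypothetical; the real loop also handles the literal text between gens):

    def parse_spec(comp, spec_format):
        """Split one speculated completion back into named variables."""
        variables = {}
        for term in spec_format:
            stop = term["stop"] if term["stop"] is not None else ""
            pos = comp.find(stop) if stop else len(comp)
            if pos == -1:
                pos = len(comp)
            if term["name"] is not None:
                variables[term["name"]] = comp[:pos]
            comp = comp[pos + len(stop):]  # drop the consumed stop string
        return variables

    fmt = [{"text": "", "stop": "\n", "name": "name"},
           {"text": "", "stop": "\n", "name": "birthday"}]
    print(parse_spec("Grace Hopper\nDecember 9, 1906\n", fmt))
    # {'name': 'Grace Hopper', 'birthday': 'December 9, 1906'}
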


@@ -304,6 +304,7 @@ def test_image_qa():
temperature=0,
max_new_tokens=64,
)
assert (
"taxi" in state.messages()[-1]["content"]
or "car" in state.messages()[-1]["content"]
@@ -349,3 +350,46 @@ def test_regex():
state = regex_gen.run()
answer = state["answer"]
assert re.match(regex, answer)
def test_completion_speculative():
@sgl.function(api_num_spec_tokens=64)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
@sgl.function
def gen_character_no_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
token_usage = sgl.global_config.default_backend.token_usage
token_usage.reset()
gen_character_spec().sync()
usage_with_spec = token_usage.prompt_tokens
token_usage.reset()
gen_character_no_spec().sync()
usage_with_no_spec = token_usage.prompt_tokens
assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
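Why the assertion should hold: without speculation, each of the three sgl.gen calls resends the (growing) prompt, so the shared prefix is billed roughly three times; with api_num_spec_tokens=64 the prompt is sent once and name, birthday, and job are all cut out of a single speculated continuation. prompt_tokens with speculation should therefore be roughly a third of the baseline, modulo retries and any regeneration triggered when a stop string is not found.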
def test_chat_completion_speculative():
@sgl.function(api_num_spec_tokens=256)
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
s += sgl.user("Please generate new Name, Birthday and Job.\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
gen_character_spec().sync()
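
For reference, both tests assume a default backend has already been configured; a typical setup might look like the sketch below (model names are examples only — the completion-style test needs a completion model, while the chat test needs a chat model):

    import sglang as sgl

    # Example-only model names; any completion/chat model pair works.
    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
    test_completion_speculative()

    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
    test_chat_completion_speculative()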