Fix openai speculative execution (#456)
@@ -1,22 +1,25 @@
 """
 Usage:
-***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution.
+***Note: for speculative execution to work, user must put all "gen" in "assistant".
+Show in "assistant" the desired answer format. Each "gen" term should have a stop token.
+The stream mode is not supported in speculative execution.
+
 E.g.
 correct:
 sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
 incorrect:
 s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
 s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
 s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
 
 export OPENAI_API_KEY=sk-******
-python3 openaichat_speculative.py
+python3 openai_chat_speculative.py
 """
 import sglang as sgl
-from sglang import function, gen, set_default_backend, OpenAI
+from sglang import function, set_default_backend, OpenAI
 
 
-@function(api_num_spec_tokens=512)
+@function(api_num_spec_tokens=256)
 def gen_character_spec(s):
     s += sgl.system("You are a helpful assistant.")
     s += sgl.user("Construct a character within the following format:")
@@ -25,7 +28,7 @@ def gen_character_spec(s):
     s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
 
 
-@function(api_num_spec_tokens=512)
+@function(api_num_spec_tokens=256)
 def gen_character_spec_no_few_shot(s):
     s += sgl.user("Construct a character. For each field stop with a newline\n")
     s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
@@ -44,18 +47,21 @@ def multi_turn_question(s, question_1, question_2):
     s += sgl.user("Answer questions in the following format:")
     s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
     s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
-    s += sgl.user("Question 1: "+question_1+"\nQuestion 2: "+question_2)
-    s += sgl.assistant("Answer 1: "+sgl.gen("answer_1", stop="\n") + "\nAnswer 2: "+ sgl.gen("answer_2", stop="\n"))
+    s += sgl.user("Question 1: " + question_1+"\nQuestion 2: " + question_2)
+    s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
 
 
 def test_spec_single_turn():
+    backend.token_usage.reset()
+
     state = gen_character_spec.run()
     for m in state.messages():
         print(m["role"], ":", m["content"])
 
     print("\n-- name:", state["name"])
-    print("\n-- birthday:", state["birthday"])
-    print("\n-- job:", state["job"])
+    print("-- birthday:", state["birthday"])
+    print("-- job:", state["job"])
+    print(backend.token_usage)
 
 
 def test_inaccurate_spec_single_turn():
@@ -99,7 +105,8 @@ def test_spec_multi_turn_stream():
 
 
 if __name__ == "__main__":
-    set_default_backend(OpenAI("gpt-4-turbo"))
+    backend = OpenAI("gpt-4-turbo")
+    set_default_backend(backend)
 
     print("\n========== test spec single turn ==========\n")
     # expect reasonable answer for each field
@@ -119,5 +126,4 @@ if __name__ == "__main__":
 
     print("\n========== test spec multi turn stream ==========\n")
     # expect error in stream_executor: stream is not supported...
     test_spec_multi_turn_stream()
-
@@ -5,7 +5,7 @@ python3 openai_speculative.py
 from sglang import function, gen, set_default_backend, OpenAI
 
 
-@function(api_num_spec_tokens=512)
+@function(api_num_spec_tokens=64)
 def gen_character_spec(s):
     s += "Construct a character within the following format:\n"
     s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
@@ -14,6 +14,15 @@ def gen_character_spec(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"
 
 
+@function
+def gen_character_no_spec(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
 @function(api_num_spec_tokens=64)
 def gen_character_spec_no_few_shot(s):
     # s += "Construct a character with name, birthday, and job:\n"
@@ -22,17 +31,19 @@ def gen_character_spec_no_few_shot(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"
 
 
-set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
+if __name__ == "__main__":
+    backend = OpenAI("gpt-3.5-turbo-instruct")
+    set_default_backend(backend)
 
-state = gen_character_spec.run()
+    for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]:
+        backend.token_usage.reset()
 
-print("...name:", state["name"])
-print("...birthday:", state["birthday"])
-print("...job:", state["job"])
+        print(f"function: {function.func.__name__}")
 
-state = gen_character_spec_no_few_shot.run()
-
-print("\n...name:", state["name"])
-print("...birthday:", state["birthday"])
-print("...job:", state["job"])
+        state = function.run()
+
+        print("...name:", state["name"])
+        print("...birthday:", state["birthday"])
+        print("...job:", state["job"])
+        print(backend.token_usage)
+        print()
@@ -9,7 +9,6 @@ class BaseBackend:
     def __init__(self) -> None:
         self.support_concate_and_append = False
         self.chat_template = get_chat_template("default")
-        self.api_num_spec_tokens = None
 
     def get_model_name(self):
         raise NotImplementedError()
@@ -1,6 +1,7 @@
 import logging
 import time
 import warnings
+import dataclasses
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -42,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [
 ]
 
 
+@dataclasses.dataclass
+class TokenUsage:
+    prompt_tokens: int
+    completion_tokens: int
+
+    def reset(self):
+        self.prompt_tokens = self.completion_tokens = 0
+
+
 class OpenAI(BaseBackend):
     def __init__(
         self,
@@ -83,66 +93,73 @@ class OpenAI(BaseBackend):
 
         self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
 
+        # Usage
+        self.token_usage = TokenUsage(0, 0)
+
+        # API speculative execution
+        # TODO(ying): This does not support multi-threading (run_batch)
         self.spec_kwargs = {}
         self.spec_format = []
         self.spec_max_num_tries = 3
-        self.api_num_spec_tokens = None
 
-    def set_api_num_spec_tokens(self, num):
-        self.api_num_spec_tokens = num
 
     def get_chat_template(self):
         return self.chat_template
 
+    def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
+            api_num_spec_tokens: int, spec_var_name: str):
+        if "max_tokens" not in self.spec_kwargs:
+            self.spec_kwargs["max_tokens"] = api_num_spec_tokens
+        else:
+            assert (
+                self.spec_kwargs["max_tokens"] == api_num_spec_tokens
+            )
+
+        params = sampling_params.to_openai_kwargs()
+        for key, value in params.items():
+            if key in ["stop"]:
+                continue
+            if key in ["max_tokens"]:
+                warnings.warn(
+                    "The parameter max_tokens will be overwritten by speculated number of tokens."
+                )
+                continue
+            if key not in self.spec_kwargs:
+                self.spec_kwargs[key] = value
+            else:
+                assert (
+                    value == self.spec_kwargs[key]
+                ), "sampling parameters should be consistent if turn on api speculative execution."
+        self.spec_format.append(
+            {"text": "", "stop": params["stop"], "name": spec_var_name}
+        )
+        return "", {}
+
     def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
-        name=None,
+        spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if self.api_num_spec_tokens is None:
+                if s.api_num_spec_tokens is None:
                     if not s.text_.endswith(self.chat_prefix):
                         raise RuntimeError(
                             "This use case is not supported if api speculative execution is off. "
-                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant."
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                             "Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
                         )
                     prompt = s.messages_
                 else:
-                    # collect assistant answer format
-                    if "max_tokens" not in self.spec_kwargs:
-                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
-                    else:
-                        assert (
-                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
-                        )
-                    params = sampling_params.to_openai_kwargs()
-                    for key, value in params.items():
-                        if key in ["stop"]:
-                            continue
-                        if key in ["max_tokens"]:
-                            warnings.warn(
-                                "The parameter max_tokens will be overwritten by speculated number of tokens."
-                            )
-                            continue
-                        if key not in self.spec_kwargs:
-                            self.spec_kwargs[key] = value
-                        else:
-                            assert (
-                                value == self.spec_kwargs[key]
-                            ), "sampling parameters should be consistent if turn on api speculative execution."
-                    self.spec_format.append(
-                        {"text": "", "stop": params["stop"], "name": name}
-                    )
-                    return "", {}
+                    return self._prepare_spec_execution(sampling_params,
+                        s.api_num_spec_tokens, spec_var_name)
             else:
                 prompt = s.text_
 
         kwargs = sampling_params.to_openai_kwargs()
         comp = openai_completion(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=prompt,
@@ -156,6 +173,7 @@ class OpenAI(BaseBackend):
             kwargs.pop("stop")
         comp = openai_completion(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=s.text_ + '"',
@@ -171,6 +189,7 @@ class OpenAI(BaseBackend):
             kwargs.pop("stop")
         comp = openai_completion(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=s.text_,
@@ -211,14 +230,16 @@ class OpenAI(BaseBackend):
         self,
         s: StreamExecutor,
     ):
-        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+        if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
             return
 
         comp = ""
         if not all(x["name"] is None for x in self.spec_format):
+            # TODO(ying): throw errors or warnings
             for i in range(self.spec_max_num_tries):
                 comp = openai_completion(
                     client=self.client,
+                    token_usage=self.token_usage,
                     is_chat=self.is_chat_model,
                     model=self.model_name,
                     prompt=s.messages_,
@@ -228,7 +249,6 @@ class OpenAI(BaseBackend):
                     break
 
         for term in self.spec_format:
-            stop = term["stop"] if term["stop"] is not None else ""
             s.text_ += term["text"]
             name = term["name"]
             if name is not None:
@@ -258,6 +278,7 @@ class OpenAI(BaseBackend):
         kwargs = sampling_params.to_openai_kwargs()
         generator = openai_completion_stream(
             client=self.client,
+            token_usage=self.token_usage,
             is_chat=self.is_chat_model,
             model=self.model_name,
             prompt=prompt,
@@ -303,6 +324,8 @@ class OpenAI(BaseBackend):
         )
         ret_str = ret.choices[0].text
         ret_token = self.tokenizer.encode(ret_str)[0]
+        self.token_usage.prompt_tokens += ret.usage.prompt_tokens
+        self.token_usage.completion_tokens= ret.usage.completion_tokens
 
         # TODO:
         # 1. return logits as the scores
@@ -332,7 +355,7 @@ class OpenAI(BaseBackend):
         return decision, scores, None, None
 
 
-def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -346,6 +369,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
                 comp = [c.text for c in ret.choices]
             else:
                 comp = ret.choices[0].text
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
@@ -359,16 +385,19 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
     return comp
 
 
-def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
+def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt, stream=True, **kwargs
+                    messages=prompt, stream=True, stream_options={"include_usage": True},
+                    **kwargs
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     try:
                         content = ret.choices[0].delta.content
                     except IndexError:
@@ -376,11 +405,17 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt, stream=True, **kwargs
+                    prompt=prompt, stream=True, stream_options={"include_usage": True},
+                    **kwargs
                 )
                 for ret in generator:
+                    if len(ret.choices) == 0:
+                        continue
                     content = ret.choices[0].text
                     yield content or "", {}
+
+            token_usage.prompt_tokens += ret.usage.prompt_tokens
+            token_usage.completion_tokens += ret.usage.completion_tokens
             break
         except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
             logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
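The streamed accounting above relies on the OpenAI client's stream_options={"include_usage": True} flag: the final streamed chunk arrives with an empty choices list and carries the request's usage. A minimal standalone sketch of that pattern, assuming the official openai Python client (v1.x) and an OPENAI_API_KEY in the environment:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

prompt_tokens = completion_tokens = 0
for chunk in stream:
    if len(chunk.choices) == 0:
        # The usage-only final chunk has no choices.
        prompt_tokens += chunk.usage.prompt_tokens
        completion_tokens += chunk.usage.completion_tokens
        continue
    print(chunk.choices[0].delta.content or "", end="")

print(f"\nprompt={prompt_tokens}, completion={completion_tokens}")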
@@ -196,12 +196,6 @@ class StreamExecutor:
         # For completion
         self.text_ = ""  # The full text
 
-        # For speculative execution
-        from sglang.backend.openai import OpenAI
-        if isinstance(backend, OpenAI):
-            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None
 
+        # For speculative execution
+        self.api_num_spec_tokens = api_num_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)
 
+        # TODO(ying): handle API speculative execution
+
         return exes
 
     def text(self):
@@ -399,7 +399,7 @@ class StreamExecutor:
 
         if (
             self.cur_role == "assistant"
-            and self.backend.api_num_spec_tokens is not None
+            and self.api_num_spec_tokens is not None
             and self.backend.is_chat_model
             and not prefix
         ):
@@ -435,71 +435,80 @@ class StreamExecutor:
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
 
+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.api_num_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name
 
         if not self.stream:
-            if self.backend.api_num_spec_tokens is None:
+            if self.api_num_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
                     self,
                     sampling_params=sampling_params,
                 )
-            elif self.backend.is_chat_model:
-                # spec on model with only chat interface
-                comp, meta_info = self.backend.generate(
-                    self,
-                    sampling_params=sampling_params,
-                    name=name,
-                )
-                return
-
-            else:  # spec on model with completion
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                        return pos
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos = find_stop()
-                    if stop_pos == -1:
-                        stop_pos = min(
-                            sampling_params.max_new_tokens,
-                            len(self.speculated_text),
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)
 
         self.text_ += comp
 
@@ -508,7 +517,7 @@ class StreamExecutor:
                 self.variable_event[name].set()
         else:
             assert (
-                self.backend.api_num_spec_tokens is None
+                self.api_num_spec_tokens is None
             ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
@@ -571,9 +580,10 @@ class StreamExecutor:
     def _execute_role_end(self, expr: SglRoleEnd):
         if (
             self.cur_role == "assistant"
-            and self.backend.api_num_spec_tokens is not None
             and self.backend.is_chat_model
+            and self.api_num_spec_tokens is not None
         ):
+            # Execute the stored lazy generation calls
             self.backend.role_end_generate(self)
         self.cur_role = None
 
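The _spec_gen helper moved into StreamExecutor above implements a speculate-then-slice scheme: the first gen call issues one request with stop strings removed and max_new_tokens raised to api_num_spec_tokens, and each later gen call is served by cutting the speculated text at its own stop string. A minimal self-contained illustration of the slicing step (a sketch of the logic, not the sglang source; spec_gen is a hypothetical free-function rendering of it):

def spec_gen(speculated: str, stop, max_new_tokens: int):
    """Return (completion, remaining speculated text)."""
    if stop is None:
        return speculated[:max_new_tokens], speculated[max_new_tokens:]
    stops = [stop] if isinstance(stop, str) else list(stop)
    # Cut at the earliest stop string; fall back to max_new_tokens.
    found = [p for p in (speculated.find(x) for x in stops) if p != -1]
    cut = min(found) if found else min(max_new_tokens, len(speculated))
    return speculated[:cut], speculated[cut:]


# One speculative completion can satisfy several gen() calls in turn.
text = "Alice\nBirthday: March 1, 1990\nJob: Engineer\n"
name, text = spec_gen(text, "\n", 64)
assert name == "Alice"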
@@ -304,6 +304,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
@@ -349,3 +350,46 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(api_num_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(api_num_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+    gen_character_spec().sync()
@@ -14,6 +14,8 @@ from sglang.test.test_programs import (
     test_select,
     test_stream,
     test_tool_use,
+    test_completion_speculative,
+    test_chat_completion_speculative
 )
 
 
@@ -78,6 +80,14 @@ class TestOpenAIBackend(unittest.TestCase):
         set_default_backend(self.backend)
         test_stream()
 
+    def test_completion_speculative(self):
+        set_default_backend(self.backend)
+        test_completion_speculative()
+
+    def test_chat_completion_speculative(self):
+        set_default_backend(self.chat_backend)
+        test_chat_completion_speculative()
+
 
 if __name__ == "__main__":
     unittest.main(warnings="ignore")
@@ -87,4 +97,4 @@ if __name__ == "__main__":
|
|||||||
# global_config.verbosity = 2
|
# global_config.verbosity = 2
|
||||||
# t = TestOpenAIBackend()
|
# t = TestOpenAIBackend()
|
||||||
# t.setUp()
|
# t.setUp()
|
||||||
# t.test_few_shot_qa()
|
# t.test_chat_completion_speculative()
|
||||||
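Taken together, the post-patch usage pattern looks like the sketch below, assembled from the examples and tests in this diff (it assumes the patched sglang and an OPENAI_API_KEY in the environment; every gen sits inside one sgl.assistant call and carries a stop token):

import sglang as sgl


@sgl.function(api_num_spec_tokens=256)
def gen_character_spec(s):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Construct a character within the following format:")
    s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
    s += sgl.user("Please generate new Name, Birthday and Job.\n")
    # Correct: all gen() calls live in a single sgl.assistant(), each with a stop token.
    s += sgl.assistant(
        "Name:" + sgl.gen("name", stop="\n")
        + "\nBirthday:" + sgl.gen("birthday", stop="\n")
        + "\nJob:" + sgl.gen("job", stop="\n")
    )


if __name__ == "__main__":
    backend = sgl.OpenAI("gpt-4-turbo")
    sgl.set_default_backend(backend)

    backend.token_usage.reset()
    state = gen_character_spec.run()
    print(state["name"], state["birthday"], state["job"])
    print(backend.token_usage)  # one speculative request instead of three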