Fix openai speculative execution (#456)
This commit is contained in:
@@ -1,22 +1,25 @@
|
||||
"""
|
||||
Usage:
|
||||
***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution.
|
||||
***Note: for speculative execution to work, user must put all "gen" in "assistant".
|
||||
Show in "assistant" the desired answer format. Each "gen" term should have a stop token.
|
||||
The stream mode is not supported in speculative execution.
|
||||
|
||||
E.g.
|
||||
correct:
|
||||
sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
||||
incorrect:
|
||||
incorrect:
|
||||
s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
|
||||
s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
|
||||
s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
|
||||
|
||||
export OPENAI_API_KEY=sk-******
|
||||
python3 openaichat_speculative.py
|
||||
python3 openai_chat_speculative.py
|
||||
"""
|
||||
import sglang as sgl
|
||||
from sglang import function, gen, set_default_backend, OpenAI
|
||||
from sglang import function, set_default_backend, OpenAI
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=512)
|
||||
@function(api_num_spec_tokens=256)
|
||||
def gen_character_spec(s):
|
||||
s += sgl.system("You are a helpful assistant.")
|
||||
s += sgl.user("Construct a character within the following format:")
|
||||
@@ -25,7 +28,7 @@ def gen_character_spec(s):
|
||||
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=512)
|
||||
@function(api_num_spec_tokens=256)
|
||||
def gen_character_spec_no_few_shot(s):
|
||||
s += sgl.user("Construct a character. For each field stop with a newline\n")
|
||||
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
||||
@@ -44,18 +47,21 @@ def multi_turn_question(s, question_1, question_2):
|
||||
s += sgl.user("Answer questions in the following format:")
|
||||
s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
|
||||
s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
|
||||
s += sgl.user("Question 1: "+question_1+"\nQuestion 2: "+question_2)
|
||||
s += sgl.assistant("Answer 1: "+sgl.gen("answer_1", stop="\n") + "\nAnswer 2: "+ sgl.gen("answer_2", stop="\n"))
|
||||
s += sgl.user("Question 1: " + question_1+"\nQuestion 2: " + question_2)
|
||||
s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
|
||||
|
||||
|
||||
def test_spec_single_turn():
|
||||
backend.token_usage.reset()
|
||||
|
||||
state = gen_character_spec.run()
|
||||
for m in state.messages():
|
||||
print(m["role"], ":", m["content"])
|
||||
|
||||
print("\n-- name:", state["name"])
|
||||
print("\n-- birthday:", state["birthday"])
|
||||
print("\n-- job:", state["job"])
|
||||
print("-- birthday:", state["birthday"])
|
||||
print("-- job:", state["job"])
|
||||
print(backend.token_usage)
|
||||
|
||||
|
||||
def test_inaccurate_spec_single_turn():
|
||||
@@ -99,7 +105,8 @@ def test_spec_multi_turn_stream():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
set_default_backend(OpenAI("gpt-4-turbo"))
|
||||
backend = OpenAI("gpt-4-turbo")
|
||||
set_default_backend(backend)
|
||||
|
||||
print("\n========== test spec single turn ==========\n")
|
||||
# expect reasonable answer for each field
|
||||
@@ -119,5 +126,4 @@ if __name__ == "__main__":
|
||||
|
||||
print("\n========== test spec multi turn stream ==========\n")
|
||||
# expect error in stream_executor: stream is not supported...
|
||||
test_spec_multi_turn_stream()
|
||||
|
||||
test_spec_multi_turn_stream()
|
||||
@@ -5,7 +5,7 @@ python3 openai_speculative.py
|
||||
from sglang import function, gen, set_default_backend, OpenAI
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=512)
|
||||
@function(api_num_spec_tokens=64)
|
||||
def gen_character_spec(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
@@ -14,6 +14,15 @@ def gen_character_spec(s):
|
||||
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
||||
|
||||
|
||||
@function
|
||||
def gen_character_no_spec(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
||||
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=64)
|
||||
def gen_character_spec_no_few_shot(s):
|
||||
# s += "Construct a character with name, birthday, and job:\n"
|
||||
@@ -22,17 +31,19 @@ def gen_character_spec_no_few_shot(s):
|
||||
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
||||
|
||||
|
||||
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
|
||||
if __name__ == "__main__":
|
||||
backend = OpenAI("gpt-3.5-turbo-instruct")
|
||||
set_default_backend(backend)
|
||||
|
||||
state = gen_character_spec.run()
|
||||
for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]:
|
||||
backend.token_usage.reset()
|
||||
|
||||
print("...name:", state["name"])
|
||||
print("...birthday:", state["birthday"])
|
||||
print("...job:", state["job"])
|
||||
print(f"function: {function.func.__name__}")
|
||||
|
||||
state = gen_character_spec_no_few_shot.run()
|
||||
|
||||
print("\n...name:", state["name"])
|
||||
print("...birthday:", state["birthday"])
|
||||
print("...job:", state["job"])
|
||||
state = function.run()
|
||||
|
||||
print("...name:", state["name"])
|
||||
print("...birthday:", state["birthday"])
|
||||
print("...job:", state["job"])
|
||||
print(backend.token_usage)
|
||||
print()
|
||||
Reference in New Issue
Block a user