Fix openai speculative execution (#456)

This commit is contained in:
Ying Sheng
2024-05-20 17:01:13 -07:00
committed by GitHub
parent ec380dfd30
commit 3e684be7a3
7 changed files with 243 additions and 128 deletions

View File

@@ -304,6 +304,7 @@ def test_image_qa():
temperature=0,
max_new_tokens=64,
)
assert (
"taxi" in state.messages()[-1]["content"]
or "car" in state.messages()[-1]["content"]
@@ -349,3 +350,46 @@ def test_regex():
state = regex_gen.run()
answer = state["answer"]
assert re.match(regex, answer)
def test_completion_speculative():
    """Check that the speculative completion program uses fewer prompt tokens.

    Two programs with identical prompts are defined: one declared with
    ``api_num_spec_tokens=64`` and one without. Each is run to completion and
    the ``prompt_tokens`` reading of the default backend's ``token_usage`` is
    taken afterwards; the speculative reading must be strictly smaller.
    """

    @sgl.function(api_num_spec_tokens=64)
    def gen_character_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    @sgl.function
    def gen_character_no_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    token_usage = sgl.global_config.default_backend.token_usage

    def prompt_tokens_for(program):
        # Reset before each run so the reading is taken from a fresh counter
        # (presumably reset() zeroes it -- confirm against the backend).
        token_usage.reset()
        program().sync()
        return token_usage.prompt_tokens

    # Speculative run first, then the plain run, each measured in isolation.
    usage_with_spec = prompt_tokens_for(gen_character_spec)
    usage_with_no_spec = prompt_tokens_for(gen_character_no_spec)

    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative():
    """Run a chat-style program declared with ``api_num_spec_tokens=256``.

    The program builds a system/user/assistant exchange and asks for three
    fields (``name``, ``birthday``, ``job``) inside a single assistant turn.
    The test only waits for the run to finish; it makes no assertion on the
    generated content.
    """

    @sgl.function(api_num_spec_tokens=256)
    def gen_character_spec(s):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant(
            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        )
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        # All three generations live in one assistant message.
        s += sgl.assistant(
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
            + "\nJob:"
            + sgl.gen("job", stop="\n")
        )

    state = gen_character_spec()
    state.sync()