Files
sglang/python/sglang/test/test_programs.py
2024-01-15 01:15:53 -08:00

348 lines
11 KiB
Python

"""
This file contains the SGL programs used for unit testing.
"""
import json
import re
import sglang as sgl
def test_few_shot_qa():
    """Run a few-shot capital-city QA program once and then as a batch.

    The single run must mention "washington"; the batch run must return
    exactly "tokyo", "london" and "beijing" (after strip/lower), in order.
    """
    # Worked examples shown to the model before the real question.
    demos = [
        ("What is the capital of France?", "Paris"),
        ("What is the capital of Germany?", "Berlin"),
        ("What is the capital of Italy?", "Rome"),
    ]

    @sgl.function
    def few_shot_qa(s, question):
        s += "The following are questions with answers.\n\n"
        for demo_question, demo_answer in demos:
            s += "Q: " + demo_question + "\n"
            s += "A: " + demo_answer + "\n"
        s += "Q: " + question + "\n"
        # Stop at the first newline so only one answer line is produced.
        s += "A:" + sgl.gen("answer", stop="\n", temperature=0)

    # Single request.
    single = few_shot_qa.run(question="What is the capital of the United States?")
    single_answer = single["answer"]
    assert "washington" in single_answer.strip().lower(), f"answer: {single_answer}"

    # Batched requests.
    batch_questions = [
        "What is the capital of Japan?",
        "What is the capital of the United Kingdom?",
        "What is the capital city of China?",
    ]
    rets = few_shot_qa.run_batch(
        [{"question": q} for q in batch_questions],
        temperature=0.1,
    )
    answers = []
    for state in rets:
        answers.append(state["answer"].strip().lower())
    assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"
def test_mt_bench():
    """Run a two-turn chat program and check the number of chat messages.

    The first turn uses the `sgl.user(...)` / `sgl.assistant(...)` call form,
    the second turn uses the `with s.user():` / `with s.assistant():` form.
    """

    @sgl.function
    def answer_mt_bench(s, question_1, question_2):
        s += sgl.system("You are a helpful assistant.")

        # Turn 1: role helpers used as expressions.
        s += sgl.user(question_1)
        s += sgl.assistant(sgl.gen("answer_1"))

        # Turn 2: role helpers used as context managers.
        with s.user():
            s += question_2
        with s.assistant():
            s += sgl.gen("answer_2")

    run_kwargs = {
        "question_1": "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
        "question_2": "Rewrite your previous response. Start every sentence with the letter A.",
        "temperature": 0.7,
        "max_new_tokens": 64,
    }
    ret = answer_mt_bench.run(**run_kwargs)

    # Either 4 or 5 messages are accepted.
    message_count = len(ret.messages())
    assert message_count in (4, 5)
def test_select(check_answer):
    """Check `sgl.select` on a True / False / Unknown classification prompt.

    Args:
        check_answer: When truthy, each statement must be classified with the
            exact expected label. Otherwise the test only requires that the
            selected answer is one of the offered choices.
    """
    choices = ["True", "False", "Unknown"]

    @sgl.function
    def true_or_false(s, statement):
        s += "Determine whether the statement below is True, False, or Unknown.\n"
        # The demonstration must itself be a correct statement ("Paris").
        s += "Statement: The capital of France is Paris.\n"
        s += "Answer: True\n"
        s += "Statement: " + statement + "\n"
        s += "Answer:" + sgl.select("answer", choices)

    # (statement, label expected when check_answer is set)
    cases = [
        ("The capital of Germany is Berlin.", "True"),
        ("The capital of Canada is Tokyo.", "False"),
        ("Purple is a better color than green.", "Unknown"),
    ]

    for statement, expected in cases:
        ret = true_or_false.run(statement=statement)
        if check_answer:
            # text is a method; call it so the failure message shows the
            # program text instead of a bound-method repr.
            assert ret["answer"] == expected, ret.text()
        else:
            assert ret["answer"] in choices
def test_decode_int():
    """Check `sgl.gen_int` by asking for two well-known integers."""

    @sgl.function
    def decode_int(s):
        s += "The number of hours in a day is " + sgl.gen_int("hours") + "\n"
        s += "The number of days in a year is " + sgl.gen_int("days") + "\n"

    ret = decode_int.run(temperature=0.1)

    # text is a method; call it so a failure prints the program text rather
    # than a bound-method repr.
    assert int(ret["hours"]) == 24, ret.text()
    assert int(ret["days"]) == 365, ret.text()
def test_decode_json_regex():
    """Build a JSON object field by field with regex-constrained generation.

    The text captured under the "json_output" scope must parse as JSON, with
    a string "name" and an integer "population".
    """

    @sgl.function
    def decode_json(s):
        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING

        # (key, regex) per line; every regex but the last also consumes the
        # trailing comma.
        field_patterns = [
            ("name", REGEX_STRING + ","),
            ("population", REGEX_INT + ","),
            ("area", REGEX_INT + ","),
            ("latitude", REGEX_FLOAT + ","),
            ("country", REGEX_STRING + ","),
            ("timezone", REGEX_STRING),
        ]

        s += "Generate a JSON object to describe the basic information of a city.\n"
        with s.var_scope("json_output"):
            s += "{\n"
            for key, pattern in field_patterns:
                s += ' "' + key + '": ' + sgl.gen(regex=pattern) + "\n"
            s += "}"

    ret = decode_json.run()
    city = json.loads(ret["json_output"])
    assert isinstance(city["name"], str)
    assert isinstance(city["population"], int)
def test_decode_json():
    """Build a JSON object using the typed generation helpers.

    Mixes `sgl.gen_string` / `sgl.gen_int` with `sgl.gen(dtype=...)`. The text
    captured under "json_output" must parse as JSON, with a string "name" and
    an integer "population".
    """

    @sgl.function
    def decode_json(s):
        s += "Generate a JSON object to describe the basic information of a city.\n"
        with s.var_scope("json_output"):
            s += "{\n"
            # (label, value generator, line terminator) for each JSON field.
            for label, value, tail in (
                (' "name": ', sgl.gen_string(), ",\n"),
                (' "population": ', sgl.gen_int(), ",\n"),
                (' "area": ', sgl.gen(dtype=int), ",\n"),
                (' "country": ', sgl.gen_string(), ",\n"),
                (' "timezone": ', sgl.gen(dtype=str), "\n"),
            ):
                s += label + value + tail
            s += "}"

    ret = decode_json.run()
    city = json.loads(ret["json_output"])
    assert isinstance(city["name"], str)
    assert isinstance(city["population"], int)
def test_expert_answer():
    """Generate an expert persona first, then an answer in that persona.

    The full program text must contain "paris" (case-insensitive).
    """

    @sgl.function
    def expert_answer(s, question):
        s += "Question: " + question + "\n"

        # Step 1: let the model name a suitable expert.
        expert_line = (
            "A good person to answer this question is"
            + sgl.gen("expert", stop=[".", "\n"])
            + ".\n"
        )
        s += expert_line

        # Step 2: read the generated expert back and reuse it in the prompt.
        expert = s["expert"]
        answer_prefix = "For example," + expert + " would answer that "
        s += answer_prefix + sgl.gen("answer", stop=".") + "."

    ret = expert_answer.run(question="What is the capital of France?", temperature=0.1)
    full_text = ret.text()
    assert "paris" in full_text.lower()
def test_tool_use():
    """Have the model emit a calculator expression and evaluate it locally.

    The evaluated result is appended inside the "answer" scope and must equal
    the product of the two operands.
    """

    def calculate(expression):
        # NOTE(review): eval() executes model-generated text. Acceptable only
        # because this is a test against a trusted backend; do not copy this
        # pattern into code that handles untrusted output.
        result = eval(expression)
        return f"{result}"

    @sgl.function
    def tool_use(s, lhs, rhs):
        s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
        s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
        s += f"Question: What is the product of {lhs} and {rhs}?\n"
        # Stop at ")" so only the expression inside calculate(...) is captured.
        s += "Answer: The answer is calculate(" + sgl.gen("expression", stop=")") + ") = "
        with s.var_scope("answer"):
            expression = s["expression"]
            s += calculate(expression)

    lhs = 257
    rhs = 983
    ret = tool_use(lhs=lhs, rhs=rhs, temperature=0)
    expected = lhs * rhs
    assert int(ret["answer"]) == expected
def test_react():
    """Run a ReAct-style Thought / Action / Observation loop for up to 4 steps.

    The final answer must mention "finland" or "states".
    """
    # One worked example shown to the model before the real question.
    demonstration = """
Question: Which country does the founder of Microsoft live in?
Thought 1: I need to search for the founder of Microsoft.
Action 1: Search [Founder of Microsoft].
Observation 1: The founder of Microsoft is Bill Gates.
Thought 2: I need to search for the country where Bill Gates lives in.
Action 2: Search [Where does Bill Gates live].
Observation 2: Bill Gates lives in the United States.
Thought 3: The answer is the United States.
Action 3: Finish [United States].\n
"""

    @sgl.function
    def react(s, question):
        s += demonstration
        s += "Question: " + question + "\n"
        for step in range(1, 5):
            action_var = f"action_{step}"
            s += f"Thought {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
            s += f"Action {step}: " + sgl.select(action_var, ["Search", "Finish"])
            if s[action_var] != "Search":
                # "Finish": capture the bracketed answer and stop looping.
                s += " [" + sgl.gen("answer", stop="]") + "].\n"
                break
            # "Search": the model writes both the query and the observation.
            s += " [" + sgl.gen(stop="]") + "].\n"
            s += f"Observation {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"

    ret = react.run(
        question="What country does the creator of Linux live in?",
        temperature=0.1,
    )
    # NOTE(review): "answer" is only generated when the model picks "Finish"
    # within the four steps; otherwise this lookup presumably fails.
    answer = ret["answer"].lower()
    assert any(keyword in answer for keyword in ("finland", "states"))
def test_parallel_decoding():
    """Smoke-test fork/join: draft a skeleton, expand each tip in a fork, summarize.

    The test makes no assertion on the generated text; it only runs the
    program.
    """
    max_tokens = 64
    number = 5

    @sgl.function
    def parallel_decoding(s, topic):
        s += "Act as a helpful assistant.\n"
        s += "USER: Give some tips for " + topic + ".\n"
        s += (
            "ASSISTANT: Okay. Here are "
            + str(number)
            + " concise tips, each under 8 words:\n"
        )

        # Generate skeleton
        for i in range(1, 1 + number):
            s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"

        # Generate detailed tips, one fork per tip
        forks = s.fork(number)
        for i in range(number):
            forks[i] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
            # Pass max_tokens by keyword, like every other sgl.gen call in
            # this file, instead of relying on positional parameter order.
            forks[i] += sgl.gen("detailed_tip", max_tokens=max_tokens, stop=["\n\n"])
        forks.join()

        # Concatenate tips and summarize
        s += "Here are these tips with detailed explanation:\n"
        for i in range(number):
            s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"

        s += "\nIn summary," + sgl.gen("summary", max_tokens=512)

    # The returned state is not inspected, so it is not bound to a name.
    parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
def test_parallel_encoding(check_answer=True):
    """Append three context statements through forks, then ask a question.

    Args:
        check_answer: When truthy, the generated answer must contain "Noah".
            The answer is fetched from the result either way.
    """
    answer_budget = 64

    @sgl.function
    def parallel_encoding(s, question, context_0, context_1, context_2):
        s += "USER: I will ask a question based on some statements.\n"
        s += "ASSISTANT: Sure. I will give the answer.\n"
        s += "USER: Please memorize these statements.\n"

        contexts = [context_0, context_1, context_2]

        # One fork per statement; the lambda receives the fork index.
        forks = s.fork(len(contexts))
        forks += lambda i: f"Statement {i}: " + contexts[i] + "\n"
        forks.join(mode="concate_and_append")

        s += "Now, please answer the following question. Do not list options."
        s += "\nQuestion: " + question + "\n"
        s += "ASSISTANT:" + sgl.gen("answer", max_tokens=answer_budget)

    run_kwargs = {
        "question": "Who is the father of Julian?",
        "context_0": "Ethan is the father of Liam.",
        "context_1": "Noah is the father of Julian.",
        "context_2": "Oliver is the father of Carlos.",
        "temperature": 0,
    }
    ret = parallel_encoding.run(**run_kwargs)

    # Fetch the answer unconditionally, as the lookup itself is part of the run.
    answer = ret["answer"]
    if not check_answer:
        return
    assert "Noah" in answer
def test_image_qa():
    """Ask for a description of "image.png"; the last message must mention "taxi"."""

    @sgl.function
    def image_qa(s, question):
        # The image and the question go into the same user turn.
        user_turn = sgl.image("image.png") + question
        s += sgl.user(user_turn)
        s += sgl.assistant(sgl.gen("answer"))

    run_kwargs = {
        "question": "Please describe this image in simple words.",
        "temperature": 0,
        "max_new_tokens": 64,
    }
    state = image_qa.run(**run_kwargs)

    last_message = state.messages()[-1]
    assert "taxi" in last_message["content"]
def test_stream():
    """Consume streamed output twice: the whole text, then only "answer".

    No assertion is made on the streamed content; the test only drains both
    iterators.
    """

    @sgl.function
    def qa(s, question):
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer"))

    prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."

    # Stream the complete program text.
    ret = qa(question=prompt, stream=True)
    out = "".join(ret.text_iter())

    # Stream only the text of the "answer" variable.
    ret = qa(question=prompt, stream=True)
    out = "".join(ret.text_iter("answer"))
def test_regex():
    """Check regex-constrained generation against an IPv4 address pattern."""
    # Dotted quad, each octet 0-255. The separator is an escaped literal dot;
    # unescaped, "." would match any character between octets.
    octet = r"(25[0-5]|2[0-4]\d|[01]?\d\d?)"
    regex = r"(" + octet + r"\.){3}" + octet

    @sgl.function
    def regex_gen(s):
        s += "Q: What is the IP address of the Google DNS servers?\n"
        s += "A: " + sgl.gen(
            "answer",
            temperature=0,
            regex=regex,
        )

    state = regex_gen.run()
    answer = state["answer"]
    # The generation is constrained by the regex, so the whole answer must
    # match; re.match would accept arbitrary trailing text.
    assert re.fullmatch(regex, answer), f"answer: {answer!r}"