53 lines
1.4 KiB
Python
53 lines
1.4 KiB
Python
import unittest
|
|
|
|
import sglang as sgl
|
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
|
|
|
|
|
class TestBind(unittest.TestCase):
|
|
backend = None
|
|
|
|
def setUp(self):
|
|
cls = type(self)
|
|
|
|
if cls.backend is None:
|
|
cls.backend = RuntimeEndpoint(base_url="http://localhost:30000")
|
|
|
|
def test_bind(self):
|
|
@sgl.function
|
|
def few_shot_qa(s, prompt, question):
|
|
s += prompt
|
|
s += "Q: What is the capital of France?\n"
|
|
s += "A: Paris\n"
|
|
s += "Q: " + question + "\n"
|
|
s += "A:" + sgl.gen("answer", stop="\n")
|
|
|
|
few_shot_qa_2 = few_shot_qa.bind(
|
|
prompt="The following are questions with answers.\n\n"
|
|
)
|
|
|
|
tracer = few_shot_qa_2.trace()
|
|
print(tracer.last_node.print_graph_dfs() + "\n")
|
|
|
|
def test_cache(self):
|
|
@sgl.function
|
|
def few_shot_qa(s, prompt, question):
|
|
s += prompt
|
|
s += "Q: What is the capital of France?\n"
|
|
s += "A: Paris\n"
|
|
s += "Q: " + question + "\n"
|
|
s += "A:" + sgl.gen("answer", stop="\n")
|
|
|
|
few_shot_qa_2 = few_shot_qa.bind(
|
|
prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
|
|
)
|
|
few_shot_qa_2.cache(self.backend)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main(warnings="ignore")
|
|
|
|
# t = TestBind()
|
|
# t.setUp()
|
|
# t.test_cache()
|