From 97aa9b3284566a4d84c08f7c1fee3699bf694e3d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 30 Jan 2024 05:45:27 -0800 Subject: [PATCH] Improve docs & Add JSON decode example (#121) --- README.md | 43 +++++++++-- benchmark/json_fast_forward/bench_other.py | 2 +- benchmark/json_fast_forward/bench_sglang.py | 2 +- docs/sampling_params.md | 12 +++ examples/quick_start/srt_example_llava.py | 3 + examples/usage/async_io.py | 8 +- examples/usage/choices_logprob.py | 3 +- examples/usage/json_decode.py | 81 +++++++++++++++++++++ examples/usage/openai_speculative.py | 4 + examples/usage/parallel_sample.py | 4 + examples/usage/readme_examples.py | 34 +++++++-- examples/usage/srt_example_regex.py | 24 ------ examples/usage/streaming.py | 4 + python/pyproject.toml | 2 +- python/sglang/backend/openai.py | 25 ++----- python/sglang/srt/managers/io_struct.py | 7 ++ python/sglang/test/test_programs.py | 2 +- test/srt/test_curl.sh | 9 +++ test/srt/test_httpserver_llava.py | 4 +- 19 files changed, 212 insertions(+), 61 deletions(-) create mode 100644 examples/usage/json_decode.py delete mode 100644 examples/usage/srt_example_regex.py create mode 100644 test/srt/test_curl.sh diff --git a/README.md b/README.md index d5892f131..c33c6512f 100644 --- a/README.md +++ b/README.md @@ -123,19 +123,21 @@ You can implement your prompt flow in a function decorated by `sgl.function`. You can then invoke the function with `run` or `run_batch`. The system will manage the state, chat template, parallelism and batching for you. +The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py) + ### Control Flow You can use any Python code within the function body, including control flow, nested function calls, and external libraries. ```python @sgl.function -def control_flow(s, question): - s += "To answer this question: " + question + ", " - s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". " +def tool_use(s, question): + s += "To answer this question: " + question + ". " + s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". " if s["tool"] == "calculator": s += "The math expression is" + sgl.gen("expression") - elif s["tool"] == "web browser": - s += "The website url is" + sgl.gen("url") + elif s["tool"] == "search engine": + s += "The key word to search is" + sgl.gen("word") ``` ### Parallelism @@ -170,6 +172,8 @@ def image_qa(s, image_file, question): s += sgl.assistant(sgl.gen("answer", max_tokens=256) ``` +See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py). + ### Constrained Decoding Use `regex` to specify a regular expression as a decoding constraint. This is only supported for local models. @@ -185,6 +189,35 @@ def regular_expression_gen(s): ) ``` +### JSON Decoding + +```python +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + +@sgl.function +def character_gen(s, name): + s += name + " is a character in Harry Potter. Please fill in the following information about him/her.\n" + s += sgl.gen("json_output", max_tokens=256, regex=character_regex) +``` + +See also [json_decode.py](examples/usage/json_decode.py). + + ### Batching Use `run_batch` to run a batch of requests with continuous batching. diff --git a/benchmark/json_fast_forward/bench_other.py b/benchmark/json_fast_forward/bench_other.py index 7db0e2d21..7052c148a 100644 --- a/benchmark/json_fast_forward/bench_other.py +++ b/benchmark/json_fast_forward/bench_other.py @@ -34,7 +34,7 @@ character_regex = ( # fmt: off def character_gen(name, generate): - s = name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n" + s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n" s += generate(s, max_tokens=256, regex=character_regex) return s # fmt: on diff --git a/benchmark/json_fast_forward/bench_sglang.py b/benchmark/json_fast_forward/bench_sglang.py index 6f8c94f17..d0ef786fb 100644 --- a/benchmark/json_fast_forward/bench_sglang.py +++ b/benchmark/json_fast_forward/bench_sglang.py @@ -32,7 +32,7 @@ character_regex = ( # fmt: off @sgl.function def character_gen(s, name): - s += name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n" + s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n" s += sgl.gen("json_output", max_tokens=256, regex=character_regex) # fmt: on diff --git a/docs/sampling_params.md b/docs/sampling_params.md index 07d07853d..f5e7cadb0 100644 --- a/docs/sampling_params.md +++ b/docs/sampling_params.md @@ -4,13 +4,21 @@ This doc describes the sampling parameters of the SGLang Runtime. The `/generate` endpoint accepts the following arguments in the JSON format. ```python +@dataclass class GenerateReqInput: + # The input prompt text: Union[List[str], str] + # The image input image_data: Optional[Union[List[str], str]] = None + # The sampling_params sampling_params: Union[List[Dict], Dict] = None + # The request id rid: Optional[Union[List[str], str]] = None + # Whether return logprobs of the prompts return_logprob: Optional[Union[List[bool], bool]] = None + # The start location of the prompt for return_logprob logprob_start_len: Optional[Union[List[int], int]] = None + # Whether to stream output stream: bool = False ``` @@ -84,3 +92,7 @@ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): prev = len(output) print("") ``` + +### Multi modal + +See [test_httpserver_llava.py](../test/srt/test_httpserver_llava.py). diff --git a/examples/quick_start/srt_example_llava.py b/examples/quick_start/srt_example_llava.py index f374d0a0f..b2b51622e 100644 --- a/examples/quick_start/srt_example_llava.py +++ b/examples/quick_start/srt_example_llava.py @@ -46,6 +46,9 @@ if __name__ == "__main__": runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b", tokenizer_path="llava-hf/llava-1.5-7b-hf") sgl.set_default_backend(runtime) + # Or you can use API models + # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview")) + # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision")) # Run a single request print("\n========== single ==========\n") diff --git a/examples/usage/async_io.py b/examples/usage/async_io.py index bf5dbd79a..68714812f 100644 --- a/examples/usage/async_io.py +++ b/examples/usage/async_io.py @@ -1,3 +1,7 @@ +""" +Usage: +python3 async_io.py +""" import asyncio from sglang import Runtime @@ -27,8 +31,8 @@ async def generate( if __name__ == "__main__": runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") - print("runtime ready") - + print("--- runtime ready ---\n") + prompt = "Who is Alan Turing?" sampling_params = {"max_new_tokens": 128} asyncio.run(generate(runtime, prompt, sampling_params)) diff --git a/examples/usage/choices_logprob.py b/examples/usage/choices_logprob.py index 3b5254dd0..6fb28940c 100644 --- a/examples/usage/choices_logprob.py +++ b/examples/usage/choices_logprob.py @@ -1,7 +1,8 @@ """ +Usage: python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +python choices_logprob.py """ - import sglang as sgl diff --git a/examples/usage/json_decode.py b/examples/usage/json_decode.py new file mode 100644 index 000000000..96d10e536 --- /dev/null +++ b/examples/usage/json_decode.py @@ -0,0 +1,81 @@ +""" +Usage: +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +python json_decode.py +""" +from enum import Enum + +from pydantic import BaseModel, constr +import sglang as sgl +from sglang.srt.constrained.json_schema import build_regex_from_object + + +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + + +@sgl.function +def character_gen(s, name): + s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n" + s += sgl.gen("json_output", max_tokens=256, regex=character_regex) + + +def driver_character_gen(): + state = character_gen.run(name="Hermione Granger") + print(state.text()) + + +class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace = "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + +class Wizard(BaseModel): + name: str + age: int + weapon: Weapon + + +@sgl.function +def pydantic_wizard_gen(s): + s += "Give me a description about a wizard in the JSON format.\n" + s += sgl.gen( + "character", + max_tokens=128, + temperature=0, + regex=build_regex_from_object(Wizard), # Requires pydantic >= 2.0 + ) + + +def driver_character_gen(): + state = character_gen.run(name="Hermione Granger") + print(state.text()) + + +def driver_pydantic_wizard_gen(): + state = pydantic_wizard_gen.run() + print(state.text()) + + +if __name__ == "__main__": + sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) + driver_character_gen() + # driver_pydantic_wizard_gen() diff --git a/examples/usage/openai_speculative.py b/examples/usage/openai_speculative.py index 8fb85255c..cb06428da 100644 --- a/examples/usage/openai_speculative.py +++ b/examples/usage/openai_speculative.py @@ -1,3 +1,7 @@ +""" +Usage: +python3 openai_speculative.py +""" from sglang import function, gen, set_default_backend, OpenAI diff --git a/examples/usage/parallel_sample.py b/examples/usage/parallel_sample.py index ff5a86cbc..288b48ac0 100644 --- a/examples/usage/parallel_sample.py +++ b/examples/usage/parallel_sample.py @@ -1,3 +1,7 @@ +""" +Usage: +python3 parallel_sample.py +""" import sglang as sgl diff --git a/examples/usage/readme_examples.py b/examples/usage/readme_examples.py index 3878f2efc..8789e1b13 100644 --- a/examples/usage/readme_examples.py +++ b/examples/usage/readme_examples.py @@ -1,14 +1,20 @@ +""" +Usage: +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +python readme_examples.py +""" import sglang as sgl @sgl.function def tool_use(s, question): - s += "To answer this question: " + question + ", " - s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". " + s += "To answer this question: " + question + ". " + s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". " + if s["tool"] == "calculator": s += "The math expression is" + sgl.gen("expression") - elif s["tool"] == "web browser": - s += "The website url is" + sgl.gen("url") + elif s["tool"] == "search engine": + s += "The key word to search is" + sgl.gen("word") @sgl.function @@ -28,6 +34,16 @@ def tip_suggestion(s): s += "In summary" + sgl.gen("summary") +@sgl.function +def regular_expression_gen(s): + s += "Q: What is the IP address of the Google DNS servers?\n" + s += "A: " + sgl.gen( + "answer", + temperature=0, + regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)", + ) + + @sgl.function def text_qa(s, question): s += "Q: " + question + "\n" @@ -46,6 +62,12 @@ def driver_tip_suggestion(): print("\n") +def driver_regex(): + state = regular_expression_gen.run() + print(state.text()) + print("\n") + + def driver_batching(): states = text_qa.run_batch( [ @@ -74,9 +96,11 @@ def driver_stream(): if __name__ == "__main__": - sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) + #sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) + sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) driver_tool_use() driver_tip_suggestion() + driver_regex() driver_batching() driver_stream() diff --git a/examples/usage/srt_example_regex.py b/examples/usage/srt_example_regex.py deleted file mode 100644 index 0dcae15ea..000000000 --- a/examples/usage/srt_example_regex.py +++ /dev/null @@ -1,24 +0,0 @@ -from sglang import function, gen, set_default_backend, Runtime - - -IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" - - -@function -def regex_gen(s): - s += "Q: What is the IP address of the Google DNS servers?\n" - s += "A: " + gen( - "answer", - temperature=0, - regex=IP_ADDR_REGEX, - ) - - -runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") -set_default_backend(runtime) - -state = regex_gen.run() - -print(state.text()) - -runtime.shutdown() diff --git a/examples/usage/streaming.py b/examples/usage/streaming.py index 8ea672417..20feaafbc 100644 --- a/examples/usage/streaming.py +++ b/examples/usage/streaming.py @@ -1,3 +1,7 @@ +""" +Usage: +python3 streaming.py +""" import asyncio import sglang as sgl diff --git a/python/pyproject.toml b/python/pyproject.toml index 4f2035ad1..6a133b54a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", - "pydantic", "diskcache", "cloudpickle", "pillow"] + "pydantic", "referencing", "diskcache", "cloudpickle", "pillow"] openai = ["openai>=1.0", "numpy"] anthropic = ["anthropic", "numpy"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index a0bed33df..3330f9449 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -30,21 +30,8 @@ def create_logit_bias_int(tokenizer): return mask -CHAT_MODEL_NAMES = [ - # GPT-4 - "gpt-4", - "gpt-4-32k", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4-0613", - "gpt-4-0314", - # GPT-3.5 - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-16k-0613", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-0301", +INSTRUCT_MODEL_NAMES = [ + "gpt-3.5-turbo-instruct", ] @@ -60,10 +47,10 @@ class OpenAI(BaseBackend): self.tokenizer = tiktoken.encoding_for_model(model_name) self.logit_bias_int = create_logit_bias_int(self.tokenizer) - if model_name in CHAT_MODEL_NAMES: - self.is_chat_model = True - else: + if model_name in INSTRUCT_MODEL_NAMES: self.is_chat_model = False + else: + self.is_chat_model = True self.chat_template = get_chat_template("default") @@ -235,6 +222,8 @@ def openai_completion(client, is_chat=None, prompt=None, **kwargs): def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs): try: if is_chat: + if kwargs["stop"] is None: + kwargs.pop("stop") generator = client.chat.completions.create( messages=prompt, stream=True, **kwargs ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 6b6940d1c..c1ce125ff 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -7,12 +7,19 @@ from sglang.srt.sampling_params import SamplingParams @dataclass class GenerateReqInput: + # The input prompt text: Union[List[str], str] + # The image input image_data: Optional[Union[List[str], str]] = None + # The sampling_params sampling_params: Union[List[Dict], Dict] = None + # The request id rid: Optional[Union[List[str], str]] = None + # Whether return logprobs of the prompts return_logprob: Optional[Union[List[bool], bool]] = None + # The start location of the prompt for return_logprob logprob_start_len: Optional[Union[List[int], int]] = None + # Whether to stream output stream: bool = False def post_init(self): diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index b4252e4db..9fed0b7b3 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True): def test_image_qa(): @sgl.function def image_qa(s, question): - s += sgl.user(sgl.image("image.png") + question) + s += sgl.user(sgl.image("test_image.png") + question) s += sgl.assistant(sgl.gen("answer")) state = image_qa.run( diff --git a/test/srt/test_curl.sh b/test/srt/test_curl.sh new file mode 100644 index 000000000..4362eaa93 --- /dev/null +++ b/test/srt/test_curl.sh @@ -0,0 +1,9 @@ +curl http://localhost:30000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Once upon a time,", + "sampling_params": { + "max_new_tokens": 16, + "temperature": 0 + } + }' diff --git a/test/srt/test_httpserver_llava.py b/test/srt/test_httpserver_llava.py index 25bb79c81..0f6571b45 100644 --- a/test/srt/test_httpserver_llava.py +++ b/test/srt/test_httpserver_llava.py @@ -34,7 +34,7 @@ async def test_concurrent(args): url + "/generate", { "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", - "image_data": "/home/ubuntu/sglang/test/lang/image.png", + "image_data": "test_image.png", "sampling_params": { "temperature": 0, "max_new_tokens": 16, @@ -55,7 +55,7 @@ def test_streaming(args): url + "/generate", json={ "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", - "image_data": "/home/ubuntu/sglang/test/lang/image.png", + "image_data": "test_image.png", "sampling_params": { "temperature": 0, "max_new_tokens": 128,