From c4707f1bb52c1743d1f438940d388ae0da36c92b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 16 Jan 2024 19:53:55 -0800 Subject: [PATCH] Improve docs (#17) --- README.md | 21 ++++++++++++++++--- examples/usage/readme_examples.py | 5 ++++- python/sglang/lang/interpreter.py | 6 +----- .../layers/context_flashattention_nopad.py | 1 - python/sglang/srt/layers/extend_attention.py | 1 - .../sglang/srt/managers/router/model_rpc.py | 2 +- 6 files changed, 24 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index bc2976d23..a420cd2b9 100644 --- a/README.md +++ b/README.md @@ -115,13 +115,14 @@ You can then invoke the function with `run` or `run_batch`. The system will manage the state, chat template, and parallelism for you. ### Control Flow +You can use any Python code within the function body, including control flow, nested function calls, and external libraries. + ```python @sgl.function def control_flow(s, question): s += "To answer this question: " + question + ", " s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". " - # You can use if or nested function calls if s["tool"] == "calculator": s += "The math expression is" + sgl.gen("expression") elif s["tool"] == "web browser": @@ -129,6 +130,9 @@ def control_flow(s, question): ``` ### Parallelism +Use `fork` to launch parallel prompts. +Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel. + ```python @sgl.function def tip_suggestion(s): @@ -137,7 +141,7 @@ def tip_suggestion(s): "1. Balanced Diet. 2. Regular Exercise.\n\n" ) - forks = s.fork(2) # Launch parallel prompts + forks = s.fork(2) for i, f in enumerate(forks): f += f"Now, expand tip {i+1} into a paragraph:\n" f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") @@ -148,6 +152,8 @@ def tip_suggestion(s): ``` ### Multi Modality +Use `sgl.image` to pass an image as input. + ```python @sgl.function def image_qa(s, image_file, question): @@ -156,6 +162,8 @@ def image_qa(s, image_file, question): ``` ### Constrained Decoding +Use `regex=` to specify a regular expression as a decoding constraint. + ```python @sgl.function def regular_expression_gen(s): @@ -168,6 +176,8 @@ def regular_expression_gen(s): ``` ### Batching +Use `run_batch` to run a batch of requests with continuous batching. + ```python @sgl.function def text_qa(s, question): @@ -180,10 +190,13 @@ states = text_qa.run_batch( {"question": "What is the capital of France?"}, {"question": "What is the capital of Japan?"}, ], + progress_bar=True ) ``` ### Streaming +Add `stream=True` to enable streaming. + ```python @sgl.function def text_qa(s, question): @@ -192,7 +205,9 @@ def text_qa(s, question): states = text_qa.run( question="What is the capital of France?", - temperature=0.1) + temperature=0.1, + stream=True +) for out in state.text_iter(): print(out, end="", flush=True) diff --git a/examples/usage/readme_examples.py b/examples/usage/readme_examples.py index d7b446c98..3878f2efc 100644 --- a/examples/usage/readme_examples.py +++ b/examples/usage/readme_examples.py @@ -53,6 +53,7 @@ def driver_batching(): {"question": "What is the capital of France?"}, {"question": "What is the capital of Japan?"}, ], + progress_bar=True ) for s in states: @@ -63,7 +64,9 @@ def driver_batching(): def driver_stream(): state = text_qa.run( question="What is the capital of France?", - temperature=0.1) + temperature=0.1, + stream=True + ) for out in state.text_iter(): print(out, end="", flush=True) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index b9a4c184b..5a6bb72a1 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -632,11 +632,7 @@ class ProgramState: self.stream_executor.end() def __repr__(self) -> str: - msgs = self.messages() - ret = "" - for msg in msgs: - ret += msg["role"] + ":\n" + msg["content"] + "\n" - return ret + return f"ProgramState({self.text()})" class ProgramStateGroup: diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py index 657cf9f9b..ef254a1b2 100644 --- a/python/sglang/srt/layers/context_flashattention_nopad.py +++ b/python/sglang/srt/layers/context_flashattention_nopad.py @@ -5,7 +5,6 @@ import triton import triton.language as tl from sglang.srt.utils import wrap_kernel_launcher - CUDA_CAPABILITY = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index d1269e726..0c5cebf5a 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -4,7 +4,6 @@ import triton.language as tl from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd from sglang.srt.utils import wrap_kernel_launcher - CUDA_CAPABILITY = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index bedd60914..09f98d1dd 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -2,10 +2,10 @@ import asyncio import logging import multiprocessing import time +import warnings from concurrent.futures import ThreadPoolExecutor from enum import Enum, auto from typing import Dict, List, Optional, Tuple, Union -import warnings import numpy as np import rpyc