diff --git a/examples/runtime/engine/custom_server.py b/examples/runtime/engine/custom_server.py
new file mode 100644
index 000000000..b190a463e
--- /dev/null
+++ b/examples/runtime/engine/custom_server.py
@@ -0,0 +1,53 @@
+from sanic import Sanic, text
+from sanic.response import json
+
+import sglang as sgl
+
+engine = None
+
+# Create an instance of the Sanic app
+app = Sanic("sanic-server")
+
+
+# Define an asynchronous route handler
+@app.route("/generate", methods=["POST"])
+async def generate(request):
+    prompt = request.json.get("prompt")
+    if not prompt:
+        return json({"error": "Prompt is required"}, status=400)
+
+    # async_generate returns a dict
+    result = await engine.async_generate(prompt)
+
+    return text(result["text"])
+
+
+@app.route("/generate_stream", methods=["POST"])
+async def generate_stream(request):
+    prompt = request.json.get("prompt")
+
+    if not prompt:
+        return json({"error": "Prompt is required"}, status=400)
+
+    # with stream=True, async_generate returns an async generator of dicts
+    result = await engine.async_generate(prompt, stream=True)
+
+    # https://sanic.dev/en/guide/advanced/streaming.md#streaming
+    # init the streaming response
+    response = await request.respond()
+
+    # send each chunk as soon as the engine produces it
+    async for chunk in result:
+        await response.send(chunk["text"])
+
+    await response.eof()
+
+
+def run_server():
+    global engine
+    engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    app.run(host="0.0.0.0", port=8000, single_process=True)
+
+
+if __name__ == "__main__":
+    run_server()
diff --git a/examples/runtime/srt_engine.py b/examples/runtime/engine/offline_batch_inference.py
similarity index 100%
rename from examples/runtime/srt_engine.py
rename to examples/runtime/engine/offline_batch_inference.py
diff --git a/examples/runtime/engine/readme.md b/examples/runtime/engine/readme.md
new file mode 100644
index 000000000..14209ebd6
--- /dev/null
+++ b/examples/runtime/engine/readme.md
@@ -0,0 +1,40 @@
+# SGLang Engine
+
+## Introduction
+SGLang provides a direct inference engine that can be used without launching an HTTP server. There are two common use cases:
+
+1. **Offline Batch Inference**
+2. **Custom Server on Top of the Engine**
+
+## Examples
+
+### 1. [Offline Batch Inference](./offline_batch_inference.py)
+
+In this example, we launch an SGLang engine and feed it a batch of inputs for inference. If you provide a very large batch, the engine schedules the requests intelligently to process them efficiently and avoid out-of-memory (OOM) errors.
+
+### 2. [Custom Server](./custom_server.py)
+
+This example demonstrates how to build a custom server on top of the SGLang Engine, using [Sanic](https://sanic.dev/en/) as the web framework. The server exposes both a non-streaming and a streaming endpoint.
+
+#### Steps:
+
+1. Install Sanic:
+
+```bash
+pip install sanic
+```
+
+2. Run the server:
+
+```bash
+python custom_server.py
+```
+
+3. Send requests:
+
+```bash
+curl -X POST http://localhost:8000/generate -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}'
+curl -X POST http://localhost:8000/generate_stream -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}' --no-buffer
+```
+
+This sends a non-streaming and a streaming request to the server.
\ No newline at end of file
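For readers who prefer Python over `curl`, a minimal client for the two endpoints might look like the sketch below. It assumes the server above is running on `localhost:8000` and that the `requests` package is installed; the prompts are illustrative.

```python
import requests

# Non-streaming: /generate returns the full completion as plain text.
resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "The Transformer architecture is..."},
)
print(resp.text)

# Streaming: /generate_stream sends incremental text chunks; print them as they arrive.
with requests.post(
    "http://localhost:8000/generate_stream",
    json={"prompt": "The Transformer architecture is..."},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
    print()
```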
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 233c6d29c..46f19d567 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -716,6 +716,58 @@ class Engine:
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        # TODO (ByronHsu): refactor to reduce the duplicated code
+
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream:
+            STREAM_END_SYMBOL = "data: [DONE]"
+            STREAM_CHUNK_START_SYMBOL = "data:"
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # a yield directly inside generate() would turn the whole function into a
+            # generator, so the non-streaming path could not return; hence the inner wrapper
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
     ):
         obj = GenerateReqInput(
             text=prompt,
@@ -724,13 +776,37 @@ class Engine:
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
+            stream=stream,
         )
 
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(generate_request(obj, None))
+        ret = await generate_request(obj, None)
+
+        if stream:
+            STREAM_END_SYMBOL = "data: [DONE]"
+            STREAM_CHUNK_START_SYMBOL = "data:"
+
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
 
     def shutdown(self):
         kill_child_process(os.getpid(), including_parent=False)
 
-    # TODO (ByronHsu): encode and async generate
+    # TODO (ByronHsu): encode
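To make the new surface concrete: with `stream=True`, `generate` returns a plain generator and `async_generate` returns (after one `await`) an async generator. Each yielded dict's `"text"` field carries only the newly generated delta, because the wrapper slices off the prefix the caller has already seen. A minimal usage sketch (model path and sampling params are illustrative, not part of the change):

```python
import asyncio

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
sampling_params = {"temperature": 0, "max_new_tokens": 16}

# Sync streaming: generate(..., stream=True) yields incremental text deltas.
for chunk in engine.generate("The capital of France is", sampling_params, stream=True):
    print(chunk["text"], end="", flush=True)
print()

# Async streaming: async_generate(..., stream=True) must itself be awaited
# to obtain the async generator.
async def stream_async():
    generator = await engine.async_generate(
        "The capital of France is", sampling_params, stream=True
    )
    async for chunk in generator:
        print(chunk["text"], end="", flush=True)
    print()

asyncio.get_event_loop().run_until_complete(stream_async())
engine.shutdown()
```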
diff --git a/python/sglang/test/few_shot_gsm8k_engine.py b/python/sglang/test/few_shot_gsm8k_engine.py
new file mode 100644
index 000000000..67844e2f1
--- /dev/null
+++ b/python/sglang/test/few_shot_gsm8k_engine.py
@@ -0,0 +1,144 @@
+import argparse
+import ast
+import asyncio
+import json
+import re
+import time
+
+import numpy as np
+
+import sglang as sgl
+from sglang.api import set_default_backend
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+INVALID = -9999999
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+async def concurrent_generate(engine, prompts, sampling_param):
+    tasks = []
+    for prompt in prompts:
+        tasks.append(asyncio.create_task(engine.async_generate(prompt, sampling_param)))
+
+    outputs = await asyncio.gather(*tasks)
+    return outputs
+
+
+def run_eval(args):
+    # Launch the engine
+    engine = sgl.Engine(model_path=args.model_path, log_level="error")
+
+    if args.local_data_path is None:
+        # Read data
+        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+        filename = download_and_cache_file(url)
+    else:
+        filename = args.local_data_path
+
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        labels.append(get_answer_value(lines[i]["answer"]))
+    assert all(l != INVALID for l in labels)
+    arguments = [{"question": q} for q in questions]
+
+    # construct the prompts
+    prompts = []
+    for i, arg in enumerate(arguments):
+        q = arg["question"]
+        prompt = few_shot_examples + q
+        prompts.append(prompt)
+
+    sampling_param = {
+        "stop": ["Question", "Assistant:", "<|separator|>"],
+        "max_new_tokens": 512,
+        "temperature": 0,
+    }
+
+    # Run requests
+    tic = time.time()
+
+    loop = asyncio.get_event_loop()
+
+    outputs = loop.run_until_complete(
+        concurrent_generate(engine, prompts, sampling_param)
+    )
+
+    # End requests
+    latency = time.time() - tic
+
+    # Shutdown the engine
+    engine.shutdown()
+
+    # Parse output
+    preds = []
+
+    for output in outputs:
+        preds.append(get_answer_value(output["text"]))
+
+    # Compute accuracy
+    acc = np.mean(np.array(preds) == np.array(labels))
+    invalid = np.mean(np.array(preds) == INVALID)
+
+    # Compute speed
+    num_output_tokens = sum(
+        output["meta_info"]["completion_tokens"] for output in outputs
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    return {
+        "accuracy": acc,
+        "latency": latency,
+        "output_throughput": output_throughput,
+    }
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model-path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
+    )
+    parser.add_argument("--local-data-path", type=str, default=None)
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--num-questions", type=int, default=200)
+    args = parser.parse_args()
+    metrics = run_eval(args)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 6e9aaf960..ffdaf0fe4 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -77,5 +77,7 @@ if __name__ == "__main__":
 
     files = files[args.range_begin : args.range_end]
 
+    print("The running tests are ", files)
+
     exit_code = run_unittest_files(files, args.timeout_per_file)
     exit(exit_code)
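The core of the eval is the fan-out in `concurrent_generate`: one task per prompt, then a single `asyncio.gather`, which returns results in task-creation order so outputs line up with `labels`. A toy sketch of the same pattern, with a hypothetical `fake_generate` standing in for `engine.async_generate`:

```python
import asyncio

# fake_generate is a stand-in for engine.async_generate; it only simulates latency.
async def fake_generate(prompt):
    await asyncio.sleep(0.01)
    return {"text": f"answer to: {prompt}"}

async def main():
    prompts = [f"Question {i}: ..." for i in range(4)]
    # One task per prompt; gather awaits them concurrently and preserves order.
    tasks = [asyncio.create_task(fake_generate(p)) for p in prompts]
    outputs = await asyncio.gather(*tasks)
    for out in outputs:
        print(out["text"])

asyncio.run(main())
```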
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index d1ecd61fc..5219ef90f 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -1,19 +1,22 @@
+import asyncio
 import json
 import unittest
+from types import SimpleNamespace
 
 import sglang as sgl
+from sglang.test.few_shot_gsm8k_engine import run_eval
 from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
 
 
 class TestSRTBackend(unittest.TestCase):
 
-    def test_engine_runtime_consistency(self):
+    def test_1_engine_runtime_consistency(self):
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_MODEL_NAME_FOR_TEST
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
 
-        engine = sgl.Engine(model_path=model_path, random_seed=42)
+        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
         out1 = engine.generate(prompt, sampling_params)["text"]
         engine.shutdown()
 
@@ -28,18 +31,76 @@ class TestSRTBackend(unittest.TestCase):
         print(out2)
         assert out1 == out2, f"{out1} != {out2}"
 
-    def test_engine_multiple_generate(self):
+    def test_2_engine_multiple_generate(self):
         # just to ensure there is no issue running multiple generate calls
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_MODEL_NAME_FOR_TEST
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
 
-        engine = sgl.Engine(model_path=model_path, random_seed=42)
+        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
 
         engine.generate(prompt, sampling_params)
         engine.generate(prompt, sampling_params)
 
         engine.shutdown()
 
+    def test_3_sync_streaming_combination(self):
+
+        prompt = "AI safety is..."
+        sampling_params = {"temperature": 0.8, "top_p": 0.95}
+
+        async def async_streaming(engine):
+
+            generator = await engine.async_generate(
+                prompt, sampling_params, stream=True
+            )
+
+            async for output in generator:
+                print(output["text"], end="", flush=True)
+            print()
+
+        # Create an LLM.
+        llm = sgl.Engine(
+            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
+            log_level="error",
+        )
+
+        # 1. sync + non streaming
+        print("\n\n==== 1. sync + non streaming ====")
+        output = llm.generate(prompt, sampling_params)
+
+        print(output["text"])
+
+        # 2. sync + streaming
+        print("\n\n==== 2. sync + streaming ====")
+        output_generator = llm.generate(prompt, sampling_params, stream=True)
+        for output in output_generator:
+            print(output["text"], end="", flush=True)
+        print()
+
+        loop = asyncio.get_event_loop()
+        # 3. async + non streaming
+        print("\n\n==== 3. async + non streaming ====")
+        output = loop.run_until_complete(llm.async_generate(prompt, sampling_params))
+        print(output["text"])
+
+        # 4. async + streaming
+        print("\n\n==== 4. async + streaming ====")
+        loop.run_until_complete(async_streaming(llm))
+
+        llm.shutdown()
+
+    def test_4_gsm8k(self):
+
+        args = SimpleNamespace(
+            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
+            local_data_path=None,
+            num_shots=5,
+            num_questions=200,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["accuracy"] > 0.7
+
 
 if __name__ == "__main__":
     unittest.main()
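A note on the renamed tests: `unittest` runs methods in alphabetical order, so the `test_1_` through `test_4_` prefixes pin the execution order (the cheap consistency checks run before the heavier GSM8K evaluation). To run them directly, something like the following should work from `test/srt/` (invocations are illustrative):

```bash
# Run the whole class in the pinned order
python -m unittest test_srt_engine -v

# Run a single case, e.g. only the sync/async streaming combinations
python -m unittest test_srt_engine.TestSRTBackend.test_3_sync_streaming_combination
```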