diff --git a/examples/runtime/srt_engine.py b/examples/runtime/srt_engine.py
new file mode 100644
index 000000000..7404c7e4e
--- /dev/null
+++ b/examples/runtime/srt_engine.py
@@ -0,0 +1,28 @@
+import sglang as sgl
+
+
+def main():
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = {"temperature": 0.8, "top_p": 0.95}
+
+    # Create an LLM.
+    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for prompt, output in zip(prompts, outputs):
+        print("===============================")
+        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+
+# The __main__ guard is necessary here because sgl.Engine creates subprocesses with the "spawn" start method.
+# Spawn starts a fresh interpreter that re-imports this module; without the guard, every subprocess would create another Engine and keep spawning processes in an infinite loop.
+if __name__ == "__main__":
+    main()
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index 71d7bfecc..3c4457c98 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -1,6 +1,7 @@
 # SGL API Components

 from sglang.api import (
+    Engine,
     Runtime,
     assistant,
     assistant_begin,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
 # SGLang DSL APIs
 __all__ = [
+    "Engine",
     "Runtime",
     "assistant",
     "assistant_begin",
     "assistant_end",
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 4082deae1..68524363e 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -33,13 +33,23 @@ def function(


 def Runtime(*args, **kwargs):
-    # Avoid importing unnecessary dependency
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # Avoid importing unnecessary dependency
     from sglang.srt.server import Runtime

     return Runtime(*args, **kwargs)


+def Engine(*args, **kwargs):
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # Avoid importing unnecessary dependency
+    from sglang.srt.server import Engine
+
+    return Engine(*args, **kwargs)
+
+
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 0772816c9..c708d6f45 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -19,6 +19,7 @@ SRT = SGLang Runtime.
 """

 import asyncio
+import atexit
 import dataclasses
 import json
 import logging
@@ -161,6 +162,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     )


+# FastAPI implicitly converts the JSON request body into obj (a dataclass).
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -290,11 +292,16 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


-def launch_server(
+def launch_engine(
     server_args: ServerArgs,
-    pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    """Launch an HTTP server."""
+    """
+    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
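+
+    Blocks until each scheduler subprocess has reported back over its pipe,
+    so the engine is ready to serve requests once this function returns.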
+    """
+
     global tokenizer_manager

     # Configure global environment
@@ -355,6 +362,29 @@ def launch_server(
     for i in range(len(scheduler_pipe_readers)):
         scheduler_pipe_readers[i].recv()

+
+def launch_server(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """
+    Launch the SRT (SGLang Runtime) server.
+
+    The SRT server consists of an HTTP server and the SRT engine.
+
+    1. HTTP server: A FastAPI server that routes requests to the engine.
+    2. SRT engine:
+        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, runs the forward passes, and sends the output tokens to the Detokenizer Manager.
+        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+    Note:
+    1. The HTTP server and Tokenizer Manager both run in the main process.
+    2. Inter-process communication (IPC) is done via the ZMQ library, with each process listening on a different port.
+    """
+
+    launch_engine(server_args=server_args)
+
     # Add api key authorization
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)
@@ -435,7 +465,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         return
     model_info = res.json()

-
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -626,3 +655,48 @@ class Runtime:

     def __del__(self):
         self.shutdown()
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine for use cases where
+    launching an HTTP server adds unnecessary complexity or overhead.
+    """
+
+    def __init__(self, *args, **kwargs):
+
+        # Register shutdown with atexit so it runs automatically when the Python program terminates; users don't have to call .shutdown() explicitly.
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+        )
+
+        # asyncio.run drives the async generate_request coroutine to completion, making this call synchronous.
+        return asyncio.run(generate_request(obj, None))
+
+    def shutdown(self):
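+        # Terminate the scheduler and detokenizer subprocesses spawned by
+        # launch_engine; including_parent=False keeps this process alive.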
+        kill_child_process(os.getpid(), including_parent=False)
+
+    # TODO (ByronHsu): encode and async generate
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index bfa5f0cc7..6e9aaf960 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -19,6 +19,7 @@ suites = {
         "test_pytorch_sampling_backend.py",
         "test_server_args.py",
         "test_skip_tokenizer_init.py",
         "test_srt_endpoint.py",
+        "test_srt_engine.py",
         "test_torch_compile.py",
         "test_torchao.py",
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
new file mode 100644
index 000000000..e9e6c9783
--- /dev/null
+++ b/test/srt/test_srt_engine.py
@@ -0,0 +1,33 @@
+import json
+import unittest
+
+import sglang as sgl
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+
+
+class TestSRTEngine(unittest.TestCase):
+
+    def test_engine_runtime_consistency(self):
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        runtime = sgl.Runtime(model_path=model_path, random_seed=42)
+        out2 = json.loads(runtime.generate(prompt, sampling_params))["text"]
+        runtime.shutdown()
+
+        print("==== Answer 1 ====")
+        print(out1)
+
+        print("==== Answer 2 ====")
+        print(out2)
+        self.assertEqual(out1, out2)
+
+
+if __name__ == "__main__":
+    unittest.main()
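For reference, here is a minimal end-to-end sketch of the Engine API added above. The model path mirrors the example file; the only output field relied on is "text", since that is the only part of the result layout this diff shows, and return_logprob/top_logprobs_num are passed purely to illustrate the new generate() signature.

import sglang as sgl


def main():
    # Engine spawns its scheduler/detokenizer subprocesses with the "spawn"
    # start method, so construction must stay under the __main__ guard
    # (see examples/runtime/srt_engine.py above).
    llm = sgl.Engine(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", random_seed=42
    )

    # generate() accepts a single prompt or a batch. return_logprob and
    # top_logprobs_num are the optional pass-through arguments from
    # Engine.generate; their exact output layout is not shown in this diff,
    # so only the "text" field is read below.
    outputs = llm.generate(
        ["The capital of France is", "The future of AI is"],
        sampling_params={"temperature": 0, "max_new_tokens": 8},
        return_logprob=True,
        top_logprobs_num=2,
    )
    for out in outputs:
        print(out["text"])

    # Optional: the atexit hook registered in Engine.__init__ would call
    # shutdown() automatically at interpreter exit.
    llm.shutdown()


if __name__ == "__main__":
    main()

Because generate() wraps the request in the same GenerateReqInput that the FastAPI /generate route consumes, the results match what an equivalent HTTP request to the full SRT server would return.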