From 3f5c2f4c4aa6b8342497b612a3c35b1294bd2314 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sun, 21 Jan 2024 15:17:30 -0800 Subject: [PATCH] Add an async example (#37) --- examples/usage/async.py | 36 ++++++++++++++++++++++++++ python/sglang/backend/openai.py | 6 +++-- python/sglang/srt/server.py | 46 ++++++++++++++++++++++++++++++--- 3 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 examples/usage/async.py diff --git a/examples/usage/async.py b/examples/usage/async.py new file mode 100644 index 000000000..bf5dbd79a --- /dev/null +++ b/examples/usage/async.py @@ -0,0 +1,36 @@ +import asyncio +from sglang import Runtime + + +async def generate( + engine, + prompt, + sampling_params, +): + tokenizer = engine.get_tokenizer() + + messages = [ + {"role": "system", "content": "You will be given question answer tasks.",}, + {"role": "user", "content": prompt}, + ] + + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + stream = engine.add_request(prompt, sampling_params) + + async for output in stream: + print(output, end="", flush=True) + print() + + +if __name__ == "__main__": + runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") + print("runtime ready") + + prompt = "Who is Alan Turing?" + sampling_params = {"max_new_tokens": 128} + asyncio.run(generate(runtime, prompt, sampling_params)) + + runtime.shutdown() diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index d34605ecd..7bd763ce9 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -78,8 +78,10 @@ class OpenAI(BaseBackend): if sampling_params.dtype is None: if self.is_chat_model: if not s.text_.endswith("ASSISTANT:"): - raise RuntimeError("This use case is not supported. " - "For OpenAI chat models, sgl.gen must be right after sgl.assistant") + raise RuntimeError( + "This use case is not supported. 
" + "For OpenAI chat models, sgl.gen must be right after sgl.assistant" + ) prompt = s.messages_ else: prompt = s.text_ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index c5bbe0674..26e36ca5e 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -11,6 +11,7 @@ from typing import List, Optional # Fix a Python bug setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +import aiohttp import psutil import requests import uvicorn @@ -25,6 +26,7 @@ from sglang.srt.conversation import ( generate_chat_conv, register_conv_template, ) +from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.openai_protocol import ( @@ -402,7 +404,7 @@ class Runtime: ): host = "127.0.0.1" port = alloc_usable_network_port(1)[0] - server_args = ServerArgs( + self.server_args = ServerArgs( model_path=model_path, tokenizer_path=tokenizer_path, host=host, @@ -417,11 +419,14 @@ class Runtime: random_seed=random_seed, log_level=log_level, ) - self.url = server_args.url() + self.url = self.server_args.url() + self.generate_url = ( + f"http://{self.server_args.host}:{self.server_args.port}/generate" + ) self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) - proc = mp.Process(target=launch_server, args=(server_args, pipe_writer)) + proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer)) proc.start() self.pid = proc.pid @@ -443,5 +448,40 @@ class Runtime: parent.wait(timeout=5) self.pid = None + def get_tokenizer(self): + return get_tokenizer( + self.server_args.tokenizer_path, + tokenizer_mode=self.server_args.tokenizer_mode, + trust_remote_code=self.server_args.trust_remote_code, + ) + + async def add_request( + self, + prompt: str, + sampling_params, + ) -> None: + + json_data = { + "text": prompt, + "sampling_params": sampling_params, + 
"stream": True, + } + + pos = 0 + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.post(self.generate_url, json=json_data) as response: + async for chunk, _ in response.content.iter_chunks(): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]\n\n": + break + data = json.loads(chunk[5:].strip("\n")) + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + def __del__(self): self.shutdown()