Add an async example (#37)

examples/usage/async.py (new file, 36 lines added)
@@ -0,0 +1,36 @@
+import asyncio
+from sglang import Runtime
+
+
+async def generate(
+    engine,
+    prompt,
+    sampling_params,
+):
+    tokenizer = engine.get_tokenizer()
+
+    messages = [
+        {"role": "system", "content": "You will be given question answer tasks."},
+        {"role": "user", "content": prompt},
+    ]
+
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    stream = engine.add_request(prompt, sampling_params)
+
+    async for output in stream:
+        print(output, end="", flush=True)
+    print()
+
+
+if __name__ == "__main__":
+    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
+    print("runtime ready")
+
+    prompt = "Who is Alan Turing?"
+    sampling_params = {"max_new_tokens": 128}
+    asyncio.run(generate(runtime, prompt, sampling_params))
+
+    runtime.shutdown()
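
Note: the example file above streams one request at a time. As a minimal sketch (not part of this commit, and assuming only the Runtime.add_request and Runtime.get_tokenizer APIs introduced further down in this diff), several prompts could be streamed concurrently on one event loop:

import asyncio

from sglang import Runtime


async def collect(engine, prompt, sampling_params):
    # Drain one token stream into a full string instead of printing it.
    chunks = []
    async for out in engine.add_request(prompt, sampling_params):
        chunks.append(out)
    return "".join(chunks)


async def main(engine):
    prompts = ["Who is Alan Turing?", "Who is Grace Hopper?"]
    # add_request is an async generator backed by an HTTP stream, so each
    # collect() task yields to the loop while waiting on its own response.
    results = await asyncio.gather(
        *(collect(engine, p, {"max_new_tokens": 128}) for p in prompts)
    )
    for result in results:
        print(result)


if __name__ == "__main__":
    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    asyncio.run(main(runtime))
    runtime.shutdown()
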
@@ -78,8 +78,10 @@ class OpenAI(BaseBackend):
         if sampling_params.dtype is None:
             if self.is_chat_model:
                 if not s.text_.endswith("ASSISTANT:"):
-                    raise RuntimeError("This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant")
+                    raise RuntimeError(
+                        "This use case is not supported. "
+                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                    )
                 prompt = s.messages_
             else:
                 prompt = s.text_
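
Note: for readers unfamiliar with the check above, the pattern it enforces looks roughly like the following (illustrative only; it assumes the standard sglang frontend primitives, which this diff does not touch):

import sglang as sgl


@sgl.function
def qa(s, question):
    s += sgl.user(question)
    # With an OpenAI chat model, sgl.gen must come right after sgl.assistant:
    s += sgl.assistant(sgl.gen("answer"))
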
@@ -11,6 +11,7 @@ from typing import List, Optional
 # Fix a Python bug
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

+import aiohttp
 import psutil
 import requests
 import uvicorn
@@ -25,6 +26,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
@@ -402,7 +404,7 @@ class Runtime:
     ):
         host = "127.0.0.1"
         port = alloc_usable_network_port(1)[0]
-        server_args = ServerArgs(
+        self.server_args = ServerArgs(
             model_path=model_path,
             tokenizer_path=tokenizer_path,
             host=host,
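
Note: promoting the local server_args to self.server_args here is what lets the new get_tokenizer and add_request methods, added further below, read the tokenizer settings and the server address after construction.
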
@@ -417,11 +419,14 @@ class Runtime:
             random_seed=random_seed,
             log_level=log_level,
         )
-        self.url = server_args.url()
+        self.url = self.server_args.url()
+        self.generate_url = (
+            f"http://{self.server_args.host}:{self.server_args.port}/generate"
+        )

         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
+        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
         proc.start()
         self.pid = proc.pid

@@ -443,5 +448,40 @@ class Runtime:
             parent.wait(timeout=5)
             self.pid = None

+    def get_tokenizer(self):
+        return get_tokenizer(
+            self.server_args.tokenizer_path,
+            tokenizer_mode=self.server_args.tokenizer_mode,
+            trust_remote_code=self.server_args.trust_remote_code,
+        )
+
+    async def add_request(
+        self,
+        prompt: str,
+        sampling_params,
+    ) -> None:
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "stream": True,
+        }
+
+        pos = 0
+
+        timeout = aiohttp.ClientTimeout(total=3 * 3600)
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.post(self.generate_url, json=json_data) as response:
+                async for chunk, _ in response.content.iter_chunks():
+                    chunk = chunk.decode("utf-8")
+                    if chunk and chunk.startswith("data:"):
+                        if chunk == "data: [DONE]\n\n":
+                            break
+                        data = json.loads(chunk[5:].strip("\n"))
+                        cur = data["text"][pos:]
+                        if cur:
+                            yield cur
+                        pos += len(cur)
+
     def __del__(self):
         self.shutdown()
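
Note on the streaming loop in add_request above: the server replies with SSE-style "data:" events, each apparently carrying the cumulative text generated so far, and pos converts that into incremental deltas. A self-contained sketch of that parsing follows; the event payloads here are hypothetical, made up to illustrate the shape the loop expects, not taken from the server code:

import json

# Hypothetical events: cumulative "text", terminated by a [DONE] sentinel.
events = [
    'data: {"text": "Alan"}\n\n',
    'data: {"text": "Alan Turing was a mathematician."}\n\n',
    "data: [DONE]\n\n",
]

pos = 0
for chunk in events:
    if not chunk.startswith("data:"):
        continue
    if chunk == "data: [DONE]\n\n":
        break
    data = json.loads(chunk[5:].strip("\n"))
    cur = data["text"][pos:]  # only the newly generated suffix
    if cur:
        print(repr(cur))      # what the async generator would yield
    pos += len(cur)

Running this prints 'Alan' and then ' Turing was a mathematician.', i.e. each consumer sees only new tokens even though every event repeats the full text.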