Add an async example (#37)

This commit is contained in:
Ying Sheng
2024-01-21 15:17:30 -08:00
committed by GitHub
parent 007eeb4eb9
commit 3f5c2f4c4a
3 changed files with 83 additions and 5 deletions

View File

@@ -78,8 +78,10 @@ class OpenAI(BaseBackend):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith("ASSISTANT:"):
- raise RuntimeError("This use case is not supported. "
- "For OpenAI chat models, sgl.gen must be right after sgl.assistant")
+ raise RuntimeError(
+ "This use case is not supported. "
+ "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+ )
prompt = s.messages_
else:
prompt = s.text_

View File

@@ -11,6 +11,7 @@ from typing import List, Optional
# Fix a Python bug
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp
import psutil
import requests
import uvicorn
@@ -25,6 +26,7 @@ from sglang.srt.conversation import (
generate_chat_conv,
register_conv_template,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import (
@@ -402,7 +404,7 @@ class Runtime:
):
host = "127.0.0.1"
port = alloc_usable_network_port(1)[0]
- server_args = ServerArgs(
+ self.server_args = ServerArgs(
model_path=model_path,
tokenizer_path=tokenizer_path,
host=host,
@@ -417,11 +419,14 @@ class Runtime:
random_seed=random_seed,
log_level=log_level,
)
- self.url = server_args.url()
+ self.url = self.server_args.url()
+ self.generate_url = (
+ f"http://{self.server_args.host}:{self.server_args.port}/generate"
+ )
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
- proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
+ proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
proc.start()
self.pid = proc.pid
@@ -443,5 +448,40 @@ class Runtime:
parent.wait(timeout=5)
self.pid = None
def get_tokenizer(self):
    """Build and return the tokenizer for this runtime's server configuration."""
    args = self.server_args
    return get_tokenizer(
        args.tokenizer_path,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
async def add_request(
    self,
    prompt: str,
    sampling_params,
):
    """Stream a generation request to the server, yielding text increments.

    Fix: the original declared ``-> None`` even though this is an async
    *generator* that yields strings; the wrong annotation has been removed.

    Args:
        prompt: The text prompt sent in the request body.
        sampling_params: Sampling parameters forwarded verbatim in the
            payload (schema defined by the server; not inspected here).

    Yields:
        str: Each newly generated text fragment, in order, until the
        server sends the ``data: [DONE]`` sentinel.
    """
    json_data = {
        "text": prompt,
        "sampling_params": sampling_params,
        "stream": True,
    }
    pos = 0  # Number of characters of server text already yielded.

    # Generous timeout: long generations can legitimately take hours.
    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.post(self.generate_url, json=json_data) as response:
            async for chunk, _ in response.content.iter_chunks():
                chunk = chunk.decode("utf-8")
                # Server-sent-events style framing: only "data:" lines carry payload.
                if chunk and chunk.startswith("data:"):
                    if chunk == "data: [DONE]\n\n":
                        break
                    data = json.loads(chunk[5:].strip("\n"))
                    # The server sends cumulative text; yield only the new suffix.
                    cur = data["text"][pos:]
                    if cur:
                        yield cur
                    pos += len(cur)
def __del__(self):
    """Best-effort cleanup on garbage collection: shut the runtime down."""
    self.shutdown()