Add an async example (#37)
This commit is contained in:
@@ -78,8 +78,10 @@ class OpenAI(BaseBackend):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
if not s.text_.endswith("ASSISTANT:"):
|
||||
raise RuntimeError("This use case is not supported. "
|
||||
"For OpenAI chat models, sgl.gen must be right after sgl.assistant")
|
||||
raise RuntimeError(
|
||||
"This use case is not supported. "
|
||||
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
|
||||
)
|
||||
prompt = s.messages_
|
||||
else:
|
||||
prompt = s.text_
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import List, Optional
|
||||
# Work around a Python bug: replace threading._register_atexit with a no-op
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
|
||||
import aiohttp
|
||||
import psutil
|
||||
import requests
|
||||
import uvicorn
|
||||
@@ -25,6 +26,7 @@ from sglang.srt.conversation import (
|
||||
generate_chat_conv,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.openai_protocol import (
|
||||
@@ -402,7 +404,7 @@ class Runtime:
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
port = alloc_usable_network_port(1)[0]
|
||||
server_args = ServerArgs(
|
||||
self.server_args = ServerArgs(
|
||||
model_path=model_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
host=host,
|
||||
@@ -417,11 +419,14 @@ class Runtime:
|
||||
random_seed=random_seed,
|
||||
log_level=log_level,
|
||||
)
|
||||
self.url = server_args.url()
|
||||
self.url = self.server_args.url()
|
||||
self.generate_url = (
|
||||
f"http://{self.server_args.host}:{self.server_args.port}/generate"
|
||||
)
|
||||
|
||||
self.pid = None
|
||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
||||
proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
|
||||
proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
|
||||
proc.start()
|
||||
self.pid = proc.pid
|
||||
|
||||
@@ -443,5 +448,40 @@ class Runtime:
|
||||
parent.wait(timeout=5)
|
||||
self.pid = None
|
||||
|
||||
def get_tokenizer(self):
    """Return a tokenizer built from this runtime's stored server arguments.

    Delegates to the module-level ``get_tokenizer`` helper, passing the
    tokenizer path, tokenizer mode and trust-remote-code setting recorded
    in ``self.server_args``.
    """
    args = self.server_args
    return get_tokenizer(
        args.tokenizer_path,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
|
||||
|
||||
async def add_request(
    self,
    prompt: str,
    sampling_params,
):
    """Stream newly generated text for ``prompt`` from the runtime's server.

    POSTs a streaming request to ``self.generate_url`` and parses the
    ``data:``-prefixed events of the response. The JSON payload of each
    event carries the text under the ``"text"`` key; only the part beyond
    what has already been yielded is emitted.

    This is an async generator, so consume it with ``async for``.

    Args:
        prompt: Sent as the ``"text"`` field of the request body.
        sampling_params: Sent unchanged as the ``"sampling_params"`` field;
            it must be JSON-serializable.

    Yields:
        str: Non-empty text increments, in the order they arrive.
    """
    json_data = {
        "text": prompt,
        "sampling_params": sampling_params,
        "stream": True,
    }

    # Number of characters of data["text"] already handed to the caller.
    pos = 0
    # Bytes received so far that do not yet form a complete event.
    pending = b""

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.post(self.generate_url, json=json_data) as response:
            async for chunk, _ in response.content.iter_chunks():
                # A network chunk is not guaranteed to hold exactly one
                # event: it may end mid-event (or mid UTF-8 character) or
                # carry several events. Buffer raw bytes and only handle
                # events whose terminating blank line has arrived.
                pending += chunk
                *events, pending = pending.split(b"\n\n")
                for raw_event in events:
                    event = raw_event.decode("utf-8").strip("\n")
                    if not event.startswith("data:"):
                        continue
                    if event == "data: [DONE]":
                        # End-of-stream marker: stop generating.
                        return
                    data = json.loads(event[5:])
                    cur = data["text"][pos:]
                    if cur:
                        yield cur
                        pos += len(cur)
|
||||
|
||||
def __del__(self):
    """Shut the runtime down when this object is finalized."""
    # Finalizer hook: delegates all cleanup to shutdown().
    self.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user