[Session] Update session control interface (#2635)
This commit is contained in:
@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionParams:
|
||||
id: Optional[str] = None
|
||||
rid: Optional[str] = None
|
||||
offset: Optional[int] = None
|
||||
replace: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
@@ -58,10 +66,8 @@ class GenerateReqInput:
|
||||
# LoRA related
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
# Session id info for continual prompting
|
||||
session: Optional[
|
||||
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
|
||||
] = None
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (
|
||||
@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
|
||||
# The input embeds
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
|
||||
# Session id info for continual prompting
|
||||
session_id: Optional[str] = None
|
||||
session_rid: Optional[str] = None
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[SessionParams] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -468,6 +473,7 @@ class ProfileReq(Enum):
|
||||
@dataclass
|
||||
class OpenSessionReqInput:
|
||||
capacity_of_str_len: int
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -477,4 +483,5 @@ class CloseSessionReqInput:
|
||||
|
||||
@dataclass
|
||||
class OpenSessionReqOutput:
|
||||
session_id: str
|
||||
session_id: Optional[str]
|
||||
success: bool
|
||||
|
||||
@@ -22,7 +22,7 @@ import warnings
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -498,8 +498,10 @@ class Scheduler:
|
||||
else:
|
||||
self.stop_profile()
|
||||
elif isinstance(recv_req, OpenSessionReqInput):
|
||||
session_id = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
||||
session_id, success = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
OpenSessionReqOutput(session_id=session_id, success=success)
|
||||
)
|
||||
elif isinstance(recv_req, CloseSessionReqInput):
|
||||
self.close_session(recv_req)
|
||||
else:
|
||||
@@ -510,7 +512,11 @@ class Scheduler:
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
# Create a new request
|
||||
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||
if (
|
||||
recv_req.session_params is None
|
||||
or recv_req.session_params.id is None
|
||||
or recv_req.session_params.id not in self.sessions
|
||||
):
|
||||
|
||||
if recv_req.input_embeds is not None:
|
||||
# Generate fake input_ids based on the length of input_embeds
|
||||
@@ -532,15 +538,18 @@ class Scheduler:
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
if recv_req.session_id is not None:
|
||||
if (
|
||||
recv_req.session_params is not None
|
||||
and recv_req.session_params.id is not None
|
||||
):
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
f"Invalid request: session id {recv_req.session_id} does not exist"
|
||||
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
||||
)
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
else:
|
||||
# Create a new request from a previsou session
|
||||
session = self.sessions[recv_req.session_id]
|
||||
# Create a new request from a previous session
|
||||
session = self.sessions[recv_req.session_params.id]
|
||||
req = session.create_req(recv_req, self.tokenizer)
|
||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||
self.waiting_queue.append(req)
|
||||
@@ -1500,16 +1509,20 @@ class Scheduler:
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> str:
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id in self.sessions:
|
||||
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||
return session_id, False
|
||||
elif session_id is None:
|
||||
logger.warning(f"session id is None, cannot open.")
|
||||
return session_id, False
|
||||
else:
|
||||
self.sessions[session_id] = Session(
|
||||
recv_req.capacity_of_str_len, session_id
|
||||
)
|
||||
return session_id
|
||||
return session_id, True
|
||||
|
||||
def close_session(self, recv_req: CloseSessionReqInput):
|
||||
# handle error
|
||||
|
||||
@@ -10,41 +10,116 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
class SessionReqNode:
|
||||
def __init__(self, req, parent=None, childs=None):
|
||||
self.req = req
|
||||
self.parent = parent
|
||||
if parent is not None:
|
||||
parent.childs.append(self)
|
||||
self.childs = [] if not childs else childs
|
||||
|
||||
def clear_childs(self, req_dict):
|
||||
for req_node in self.childs:
|
||||
req_node.clear(req_dict)
|
||||
self.childs = []
|
||||
|
||||
def clear(self, req_dict):
|
||||
for req_node in self.childs:
|
||||
req_node.clear(req_dict)
|
||||
|
||||
if self.req.finished_reason == None:
|
||||
self.req.to_abort = True
|
||||
del req_dict[self.req.rid]
|
||||
|
||||
def abort(self):
|
||||
if self.req.finished_reason == None:
|
||||
self.req.to_abort = True
|
||||
|
||||
def __str__(self):
|
||||
return self._str_helper(self.req.rid)
|
||||
|
||||
def _str_helper(self, prefix=""):
|
||||
if len(self.childs) == 0:
|
||||
return prefix + "\n"
|
||||
else:
|
||||
origin_prefix = prefix
|
||||
prefix += " -- " + self.childs[0].req.rid
|
||||
ret = self.childs[0]._str_helper(prefix)
|
||||
for child in self.childs[1:]:
|
||||
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
|
||||
ret += child._str_helper(prefix)
|
||||
return ret
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, capacity_of_str_len: int, session_id: str = None):
|
||||
def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
|
||||
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
|
||||
self.capacity_of_str_len = capacity_of_str_len
|
||||
self.reqs: List[Req] = []
|
||||
self.req_nodes: Dict[str, SessionReqNode] = {}
|
||||
|
||||
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
||||
if req.session_rid is not None:
|
||||
while len(self.reqs) > 0:
|
||||
if self.reqs[-1].rid == req.session_rid:
|
||||
break
|
||||
self.reqs = self.reqs[:-1]
|
||||
assert req.session_params is not None
|
||||
session_params = req.session_params
|
||||
|
||||
last_req_node = None
|
||||
last_req = None
|
||||
abort = False
|
||||
if session_params.replace:
|
||||
if session_params.rid is None:
|
||||
for _, req_node in self.req_nodes.items():
|
||||
req_node.clear(self.req_nodes)
|
||||
else:
|
||||
if session_params.rid not in self.req_nodes:
|
||||
abort = True
|
||||
else:
|
||||
last_req_node = self.req_nodes[session_params.rid]
|
||||
last_req_node.abort()
|
||||
last_req = last_req_node.req
|
||||
last_req_node.clear_childs(self.req_nodes)
|
||||
else:
|
||||
self.reqs = []
|
||||
if len(self.reqs) > 0:
|
||||
if session_params.rid is not None:
|
||||
if session_params.rid not in self.req_nodes:
|
||||
abort = True
|
||||
else:
|
||||
last_req_node = self.req_nodes[session_params.rid]
|
||||
last_req = last_req_node.req
|
||||
if not last_req.finished():
|
||||
logging.warning(
|
||||
"The request in a session is appending to a request that hasn't finished."
|
||||
)
|
||||
abort = True
|
||||
|
||||
if last_req is not None:
|
||||
# trim bos token if it is an append
|
||||
if req.input_ids[0] == tokenizer.bos_token_id:
|
||||
req.input_ids = req.input_ids[1:]
|
||||
|
||||
input_ids = (
|
||||
self.reqs[-1].origin_input_ids
|
||||
+ self.reqs[-1].output_ids[
|
||||
: self.reqs[-1].sampling_params.max_new_tokens
|
||||
]
|
||||
+ req.input_ids
|
||||
last_req.origin_input_ids
|
||||
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
|
||||
)
|
||||
if session_params.offset and session_params.offset != 0:
|
||||
input_ids = input_ids[: session_params.offset] + req.input_ids
|
||||
else:
|
||||
input_ids += req.input_ids
|
||||
input_ids_unpadded = (
|
||||
self.reqs[-1].origin_input_ids_unpadded
|
||||
+ self.reqs[-1].output_ids[
|
||||
: self.reqs[-1].sampling_params.max_new_tokens
|
||||
]
|
||||
+ req.input_ids
|
||||
last_req.origin_input_ids_unpadded
|
||||
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
|
||||
)
|
||||
if session_params.offset and session_params.offset != 0:
|
||||
input_ids_unpadded = (
|
||||
input_ids_unpadded[: session_params.offset] + req.input_ids
|
||||
)
|
||||
else:
|
||||
input_ids_unpadded += req.input_ids
|
||||
else:
|
||||
input_ids = req.input_ids
|
||||
input_ids_unpadded = req.input_ids
|
||||
@@ -57,13 +132,13 @@ class Session:
|
||||
lora_path=req.lora_path,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
if len(self.reqs) > 0:
|
||||
new_req.image_inputs = self.reqs[-1].image_inputs
|
||||
if last_req is not None:
|
||||
new_req.image_inputs = last_req.image_inputs
|
||||
new_req.tokenizer = tokenizer
|
||||
if req.session_rid is not None and len(self.reqs) == 0:
|
||||
new_req.finished_reason = FINISH_ABORT(
|
||||
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
|
||||
)
|
||||
if abort:
|
||||
new_req.to_abort = True
|
||||
else:
|
||||
self.reqs.append(new_req)
|
||||
new_req_node = SessionReqNode(new_req, last_req_node)
|
||||
self.req_nodes[req.rid] = new_req_node
|
||||
|
||||
return new_req
|
||||
|
||||
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
SessionParams,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
@@ -264,8 +265,9 @@ class TokenizerManager:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
session_id = obj.session[0] if obj.session else None
|
||||
session_rid = obj.session[1] if obj.session else None
|
||||
session_params = (
|
||||
SessionParams(**obj.session_params) if obj.session_params else None
|
||||
)
|
||||
|
||||
if obj.input_ids is not None and len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
@@ -292,8 +294,7 @@ class TokenizerManager:
|
||||
obj.stream,
|
||||
lora_path=obj.lora_path,
|
||||
input_embeds=input_embeds,
|
||||
session_id=session_id,
|
||||
session_rid=session_rid,
|
||||
session_params=session_params,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
@@ -552,12 +553,16 @@ class TokenizerManager:
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
session_id = uuid.uuid4().hex
|
||||
obj.session_id = session_id
|
||||
if obj.session_id is None:
|
||||
obj.session_id = uuid.uuid4().hex
|
||||
elif obj.session_id in self.session_futures:
|
||||
return None
|
||||
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.session_futures[session_id] = asyncio.Future()
|
||||
session_id = await self.session_futures[session_id]
|
||||
del self.session_futures[session_id]
|
||||
|
||||
self.session_futures[obj.session_id] = asyncio.Future()
|
||||
session_id = await self.session_futures[obj.session_id]
|
||||
del self.session_futures[obj.session_id]
|
||||
return session_id
|
||||
|
||||
async def close_session(
|
||||
@@ -709,7 +714,7 @@ class TokenizerManager:
|
||||
)
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
recv_obj.session_id if recv_obj.success else None
|
||||
)
|
||||
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
|
||||
@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
"""Open a session, and return its unique session id."""
|
||||
try:
|
||||
session_id = await tokenizer_manager.open_session(obj, request)
|
||||
if session_id is None:
|
||||
raise Exception(
|
||||
"Failed to open the session. Check if a session with the same id is still open."
|
||||
)
|
||||
return session_id
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control_with_branching
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control_backtrack_with_abort
|
||||
python3 -m unittest test_session_control.TestSessionControlVision.test_session_control
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
@@ -18,6 +23,10 @@ from sglang.test.test_utils import (
|
||||
)
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str) -> str:
|
||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||
|
||||
|
||||
class TestSessionControl(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -31,15 +40,18 @@ class TestSessionControl(unittest.TestCase):
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_session_control(self):
|
||||
def test_session_control(self, gen_len=12):
|
||||
chunks = [
|
||||
"Let me tell you something about France.",
|
||||
"The capital of France is",
|
||||
"The population of the city is",
|
||||
"A brief history about that city is",
|
||||
"To plan a travel, the budget is",
|
||||
]
|
||||
tokenizer = get_tokenizer(self.model)
|
||||
chunks_ids = [tokenizer.encode(x) for x in chunks]
|
||||
for i in range(1, len(chunks_ids)):
|
||||
if chunks_ids[i][0] == tokenizer.bos_token_id:
|
||||
chunks_ids[i] = chunks_ids[i][1:]
|
||||
|
||||
# 1. using session control
|
||||
session_id = requests.post(
|
||||
@@ -48,6 +60,13 @@ class TestSessionControl(unittest.TestCase):
|
||||
).json()
|
||||
rid = None
|
||||
|
||||
# open an existing session, should get session_id as None
|
||||
response = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000, "session_id": session_id},
|
||||
).json()
|
||||
assert isinstance(response, dict) and "error" in response
|
||||
|
||||
first_rid = None
|
||||
outputs_from_session = []
|
||||
for i, chunk_ids in enumerate(chunks_ids):
|
||||
@@ -55,11 +74,16 @@ class TestSessionControl(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": chunk_ids,
|
||||
"session": [session_id, rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
gen_len if i > 0 else 1
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
@@ -77,10 +101,15 @@ class TestSessionControl(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": chunks_ids[-1],
|
||||
"session": [session_id, first_rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": first_rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -93,10 +122,15 @@ class TestSessionControl(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": chunks_ids[-1],
|
||||
"session": [session_id, rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -115,10 +149,15 @@ class TestSessionControl(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": chunks_ids[-1],
|
||||
"session": [session_id, first_rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": first_rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -127,6 +166,8 @@ class TestSessionControl(unittest.TestCase):
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
|
||||
# 2. not use session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
|
||||
input_ids_first_req = None
|
||||
input_ids = []
|
||||
outputs_normal = []
|
||||
@@ -139,7 +180,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
gen_len if i > 0 else 1
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
@@ -150,7 +191,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
output_ids = tokenizer.encode(response["text"])
|
||||
if output_ids[0] == tokenizer.bos_token_id:
|
||||
output_ids = output_ids[1:]
|
||||
input_ids += output_ids
|
||||
input_ids += output_ids[:-1]
|
||||
outputs_normal.append(response["text"])
|
||||
if i == 0:
|
||||
input_ids_first_req = input_ids.copy()
|
||||
@@ -162,7 +203,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"input_ids": input_ids_first_req,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -176,6 +217,272 @@ class TestSessionControl(unittest.TestCase):
|
||||
print(outputs_normal)
|
||||
assert outputs_from_session == outputs_normal
|
||||
|
||||
async def async_generate(self, payload):
|
||||
url = self.base_url + "/generate"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url=url, json=payload) as response:
|
||||
assert response.status == 200
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||
if chunk == "[DONE]":
|
||||
yield "", None, ""
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
finish_reason = (
|
||||
data["meta_info"]["finish_reason"]["type"]
|
||||
if data["meta_info"]["finish_reason"]
|
||||
else ""
|
||||
)
|
||||
yield data["text"], data["meta_info"]["id"], finish_reason
|
||||
|
||||
async def run_session_control_backtrack_with_abort(self, replace):
|
||||
chunks = [
|
||||
"Let me tell you something about France.",
|
||||
"The capital of France is",
|
||||
]
|
||||
tokenizer = get_tokenizer(self.model)
|
||||
chunks_ids = [tokenizer.encode(x) for x in chunks]
|
||||
for i in range(1, len(chunks_ids)):
|
||||
if chunks_ids[i][0] == tokenizer.bos_token_id:
|
||||
chunks_ids[i] = chunks_ids[i][1:]
|
||||
|
||||
# 1. using session control
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
).json()
|
||||
rid = None
|
||||
|
||||
payload = {
|
||||
"input_ids": chunks_ids[0],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 100,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"stream": True,
|
||||
}
|
||||
gen_so_far = ""
|
||||
finish_reason = ""
|
||||
second_output = ""
|
||||
async for chunk, rid, finish_reason_chunk in self.async_generate(payload):
|
||||
gen_so_far += chunk
|
||||
if finish_reason == "":
|
||||
finish_reason = finish_reason_chunk
|
||||
if len(gen_so_far) > 50 and second_output == "":
|
||||
payload2 = {
|
||||
"input_ids": chunks_ids[1],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": 50,
|
||||
"replace": replace,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"stream": False,
|
||||
"stream_output": True,
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.base_url + "/generate", json=payload2
|
||||
).json()
|
||||
second_output = response["text"]
|
||||
if replace:
|
||||
assert finish_reason == "abort"
|
||||
print("first request output:")
|
||||
print(gen_so_far)
|
||||
print("second request output:")
|
||||
print(second_output)
|
||||
|
||||
# close the session
|
||||
ret = requests.post(
|
||||
self.base_url + "/close_session",
|
||||
json={"session_id": session_id},
|
||||
)
|
||||
assert ret.status_code == 200
|
||||
|
||||
if not replace:
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
else:
|
||||
# 2. not using session control
|
||||
output_ids = tokenizer.encode(gen_so_far)
|
||||
if output_ids[0] == tokenizer.bos_token_id:
|
||||
output_ids = output_ids[1:]
|
||||
input_ids = chunks_ids[0] + output_ids
|
||||
input_ids = input_ids[:50] + chunks_ids[1]
|
||||
payload = {
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"stream": False,
|
||||
"stream_output": True,
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.base_url + "/generate", json=payload
|
||||
).json()
|
||||
output_no_session = response["text"]
|
||||
print("second request output without session:")
|
||||
print(output_no_session)
|
||||
assert second_output == output_no_session
|
||||
|
||||
def test_session_control_backtrack_with_abort(self):
|
||||
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
|
||||
asyncio.run(self.run_session_control_backtrack_with_abort(replace=False))
|
||||
|
||||
def run_session_control_with_branching(
|
||||
self, root_prompt, chunks_per_step, gen_len=16
|
||||
):
|
||||
for x in chunks_per_step:
|
||||
assert len(x) == len(chunks_per_step[0])
|
||||
|
||||
# 1. using session control
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
).json()
|
||||
|
||||
outputs_from_session = []
|
||||
# send the root prompt
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": root_prompt,
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": None,
|
||||
"offset": 0,
|
||||
"replace": False,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
rid_per_branch = [response["meta_info"]["id"]] * len(chunks_per_step[0])
|
||||
outputs_from_session.append(response["text"])
|
||||
|
||||
# send the prompts in branches
|
||||
for chunks_for_branches in chunks_per_step:
|
||||
for j, chunk in enumerate(chunks_for_branches):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": chunk,
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid_per_branch[j],
|
||||
"offset": 0,
|
||||
"replace": False,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
rid = response["meta_info"]["id"]
|
||||
rid_per_branch[j] = rid
|
||||
outputs_from_session.append(response["text"])
|
||||
|
||||
# close the session
|
||||
ret = requests.post(
|
||||
self.base_url + "/close_session",
|
||||
json={"session_id": session_id},
|
||||
)
|
||||
assert ret.status_code == 200
|
||||
|
||||
# 2. not use session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
|
||||
outputs_normal = []
|
||||
input_texts = [root_prompt] * len(chunks_per_step[0])
|
||||
# send the root prompt
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": root_prompt,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
outputs_normal.append(response["text"])
|
||||
input_texts = [x + response["text"] for x in input_texts]
|
||||
|
||||
# send the prompts in branches
|
||||
for chunks_for_branches in chunks_per_step:
|
||||
for j, chunk in enumerate(chunks_for_branches):
|
||||
input_texts[j] += chunk
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": input_texts[j],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
outputs_normal.append(response["text"])
|
||||
input_texts[j] += response["text"]
|
||||
|
||||
print("====== outputs from chunked queries with session control: =======")
|
||||
print(outputs_from_session)
|
||||
print("====== outputs from normal queries: =======")
|
||||
print(outputs_normal)
|
||||
assert outputs_from_session == outputs_normal
|
||||
|
||||
def test_session_control_with_branching(self):
|
||||
root_prompt = "First, let me explain in one sentence about AI"
|
||||
chunks_per_step = [
|
||||
[
|
||||
"Then, briefly, the positive side of AI is",
|
||||
"But, briefly, AI could be harmful to human",
|
||||
],
|
||||
["For example", "For example"],
|
||||
]
|
||||
self.run_session_control_with_branching(
|
||||
root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8
|
||||
)
|
||||
|
||||
root_prompt = "I have three apples."
|
||||
chunks_per_step = [
|
||||
["I then give one apple to my friend", "My friend give me another apple."],
|
||||
["I still have", "I now have"],
|
||||
]
|
||||
self.run_session_control_with_branching(
|
||||
root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8
|
||||
)
|
||||
|
||||
|
||||
class TestSessionControlVision(unittest.TestCase):
|
||||
@classmethod
|
||||
@@ -197,17 +504,25 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
text_chunks = [
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\nDescribe this image in a very short sentence.<|im_end|>\nassistant:",
|
||||
]
|
||||
image_chunks = [
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||
]
|
||||
assert len(text_chunks) == len(image_chunks) + 1
|
||||
|
||||
assert (
|
||||
len(text_chunks) == len(image_chunks) + 2
|
||||
) # the first and the last prompt does not contain images
|
||||
tokenizer = get_tokenizer(self.model)
|
||||
text_input_ids = [tokenizer.encode(x) for x in text_chunks]
|
||||
for i in range(1, len(text_input_ids)):
|
||||
if text_input_ids[i][0] == tokenizer.bos_token_id:
|
||||
text_input_ids[i] = text_input_ids[i][1:]
|
||||
gen_len = 32
|
||||
|
||||
# 1. using session control
|
||||
session_id = requests.post(
|
||||
@@ -216,20 +531,32 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
).json()
|
||||
rid = None
|
||||
|
||||
# open an existing session, should get session_id as None
|
||||
response = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000, "session_id": session_id},
|
||||
).json()
|
||||
assert isinstance(response, dict) and "error" in response
|
||||
|
||||
first_rid = None
|
||||
outputs_from_session = []
|
||||
for i in range(len(text_input_ids)):
|
||||
for i in range(len(text_input_ids[:-1])):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[i],
|
||||
"image_data": image_chunks[i - 1] if i > 0 else None,
|
||||
"modalities": ["multi-images"],
|
||||
"session": [session_id, rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": 0,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
gen_len if i > 0 else 0
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
@@ -247,12 +574,15 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[-1],
|
||||
"image_data": image_chunks[-1:],
|
||||
"modalities": ["multi-images"],
|
||||
"session": [session_id, first_rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": first_rid,
|
||||
"offset": 0,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -265,12 +595,15 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[-1],
|
||||
"image_data": image_chunks[-1:],
|
||||
"modalities": ["multi-images"],
|
||||
"session": [session_id, rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": 0,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -289,10 +622,15 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[-1],
|
||||
"session": [session_id, first_rid],
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": first_rid,
|
||||
"offset": 0,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
@@ -306,7 +644,7 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
input_ids_first_req = None
|
||||
input_ids = []
|
||||
outputs_normal = []
|
||||
for i in range(len(text_input_ids)):
|
||||
for i in range(len(text_input_ids[:-1])):
|
||||
input_ids += text_input_ids[i]
|
||||
image_data = image_chunks[:i] if i > 0 else None
|
||||
response = requests.post(
|
||||
@@ -318,7 +656,7 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
gen_len if i > 0 else 0
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
@@ -339,11 +677,9 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids_first_req,
|
||||
"image_data": image_chunks[-1:],
|
||||
"modalities": ["multi-images"],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"max_new_tokens": gen_len,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user