diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 202112cc8..16db89a0a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams +@dataclass +class SessionParams: + id: Optional[str] = None + rid: Optional[str] = None + offset: Optional[int] = None + replace: Optional[bool] = None + + @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. @@ -58,10 +66,8 @@ class GenerateReqInput: # LoRA related lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - # Session id info for continual prompting - session: Optional[ - Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] - ] = None + # Session info for continual prompting + session_params: Optional[Union[List[Dict], Dict]] = None def normalize_batch_and_arguments(self): if ( @@ -223,9 +229,8 @@ class TokenizedGenerateReqInput: # The input embeds input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None - # Session id info for continual prompting - session_id: Optional[str] = None - session_rid: Optional[str] = None + # Session info for continual prompting + session_params: Optional[SessionParams] = None @dataclass @@ -468,6 +473,7 @@ class ProfileReq(Enum): @dataclass class OpenSessionReqInput: capacity_of_str_len: int + session_id: Optional[str] = None @dataclass @@ -477,4 +483,5 @@ class CloseSessionReqInput: @dataclass class OpenSessionReqOutput: - session_id: str + session_id: Optional[str] + success: bool diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fe7bf0198..1f8207edc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -22,7 +22,7 @@ import warnings from collections import 
deque from concurrent import futures from types import SimpleNamespace -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import psutil import setproctitle @@ -498,8 +498,10 @@ class Scheduler: else: self.stop_profile() elif isinstance(recv_req, OpenSessionReqInput): - session_id = self.open_session(recv_req) - self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id)) + session_id, success = self.open_session(recv_req) + self.send_to_tokenizer.send_pyobj( + OpenSessionReqOutput(session_id=session_id, success=success) + ) elif isinstance(recv_req, CloseSessionReqInput): self.close_session(recv_req) else: @@ -510,7 +512,11 @@ class Scheduler: recv_req: TokenizedGenerateReqInput, ): # Create a new request - if recv_req.session_id is None or recv_req.session_id not in self.sessions: + if ( + recv_req.session_params is None + or recv_req.session_params.id is None + or recv_req.session_params.id not in self.sessions + ): if recv_req.input_embeds is not None: # Generate fake input_ids based on the length of input_embeds @@ -532,15 +538,18 @@ class Scheduler: ) req.tokenizer = self.tokenizer - if recv_req.session_id is not None: + if ( + recv_req.session_params is not None + and recv_req.session_params.id is not None + ): req.finished_reason = FINISH_ABORT( - f"Invalid request: session id {recv_req.session_id} does not exist" + f"Invalid request: session id {recv_req.session_params.id} does not exist" ) self.waiting_queue.append(req) return else: - # Create a new request from a previsou session - session = self.sessions[recv_req.session_id] + # Create a new request from a previous session + session = self.sessions[recv_req.session_params.id] req = session.create_req(recv_req, self.tokenizer) if isinstance(req.finished_reason, FINISH_ABORT): self.waiting_queue.append(req) @@ -1500,16 +1509,20 @@ class Scheduler: ) logger.info("Profiler is done") - def open_session(self, recv_req: OpenSessionReqInput) -> str: + def 
open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]: # handle error session_id = recv_req.session_id if session_id in self.sessions: logger.warning(f"session id {session_id} already exist, cannot open.") + return session_id, False + elif session_id is None: + logger.warning("session id is None, cannot open.") + return session_id, False else: self.sessions[session_id] = Session( recv_req.capacity_of_str_len, session_id ) - return session_id + return session_id, True def close_session(self, recv_req: CloseSessionReqInput): # handle error diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index dc5a1b670..e3e94ce6b 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -10,41 +10,116 @@ # limitations under the License. # ============================================================================== +import logging import uuid +from typing import Dict, Optional from sglang.srt.managers.io_struct import TokenizedGenerateReqInput -from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req +from sglang.srt.managers.schedule_batch import Req + + +class SessionReqNode: + def __init__(self, req, parent=None, childs=None): + self.req = req + self.parent = parent + if parent is not None: + parent.childs.append(self) + self.childs = [] if not childs else childs + + def clear_childs(self, req_dict): + for req_node in self.childs: + req_node.clear(req_dict) + self.childs = [] + + def clear(self, req_dict): + for req_node in self.childs: + req_node.clear(req_dict) + + if self.req.finished_reason is None: + self.req.to_abort = True + del req_dict[self.req.rid] + + def abort(self): + if self.req.finished_reason is None: + self.req.to_abort = True + + def __str__(self): + return self._str_helper(self.req.rid) + + def _str_helper(self, prefix=""): + if len(self.childs) == 0: + return prefix + "\n" + else: + 
origin_prefix = prefix + prefix += " -- " + self.childs[0].req.rid + ret = self.childs[0]._str_helper(prefix) + for child in self.childs[1:]: + prefix = " " * len(origin_prefix) + " \\- " + child.req.rid + ret += child._str_helper(prefix) + return ret class Session: - def __init__(self, capacity_of_str_len: int, session_id: str = None): + def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None): self.session_id = session_id if session_id is not None else uuid.uuid4().hex self.capacity_of_str_len = capacity_of_str_len - self.reqs: List[Req] = [] + self.req_nodes: Dict[str, SessionReqNode] = {} def create_req(self, req: TokenizedGenerateReqInput, tokenizer): - if req.session_rid is not None: - while len(self.reqs) > 0: - if self.reqs[-1].rid == req.session_rid: - break - self.reqs = self.reqs[:-1] + assert req.session_params is not None + session_params = req.session_params + + last_req_node = None + last_req = None + abort = False + if session_params.replace: + if session_params.rid is None: + for _, req_node in self.req_nodes.items(): + req_node.clear(self.req_nodes) + else: + if session_params.rid not in self.req_nodes: + abort = True + else: + last_req_node = self.req_nodes[session_params.rid] + last_req_node.abort() + last_req = last_req_node.req + last_req_node.clear_childs(self.req_nodes) else: - self.reqs = [] - if len(self.reqs) > 0: + if session_params.rid is not None: + if session_params.rid not in self.req_nodes: + abort = True + else: + last_req_node = self.req_nodes[session_params.rid] + last_req = last_req_node.req + if not last_req.finished(): + logging.warning( + "The request in a session is appending to a request that hasn't finished." 
+ ) + abort = True + + if last_req is not None: + # trim bos token if it is an append + if req.input_ids[0] == tokenizer.bos_token_id: + req.input_ids = req.input_ids[1:] + input_ids = ( - self.reqs[-1].origin_input_ids - + self.reqs[-1].output_ids[ - : self.reqs[-1].sampling_params.max_new_tokens - ] - + req.input_ids + last_req.origin_input_ids + + last_req.output_ids[: last_req.sampling_params.max_new_tokens] ) + if session_params.offset and session_params.offset != 0: + input_ids = input_ids[: session_params.offset] + req.input_ids + else: + input_ids += req.input_ids input_ids_unpadded = ( - self.reqs[-1].origin_input_ids_unpadded - + self.reqs[-1].output_ids[ - : self.reqs[-1].sampling_params.max_new_tokens - ] - + req.input_ids + last_req.origin_input_ids_unpadded + + last_req.output_ids[: last_req.sampling_params.max_new_tokens] ) + if session_params.offset and session_params.offset != 0: + input_ids_unpadded = ( + input_ids_unpadded[: session_params.offset] + req.input_ids + ) + else: + input_ids_unpadded += req.input_ids else: input_ids = req.input_ids input_ids_unpadded = req.input_ids @@ -57,13 +132,13 @@ class Session: lora_path=req.lora_path, session_id=self.session_id, ) - if len(self.reqs) > 0: - new_req.image_inputs = self.reqs[-1].image_inputs + if last_req is not None: + new_req.image_inputs = last_req.image_inputs new_req.tokenizer = tokenizer - if req.session_rid is not None and len(self.reqs) == 0: - new_req.finished_reason = FINISH_ABORT( - f"Invalid request: requested session rid {req.session_rid} does not exist in the session history" - ) + if abort: + new_req.to_abort = True else: - self.reqs.append(new_req) + new_req_node = SessionReqNode(new_req, last_req_node) + self.req_nodes[req.rid] = new_req_node + return new_req diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b98bd09fc..1c81f5e50 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ 
b/python/sglang/srt/managers/tokenizer_manager.py @@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import ( OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, @@ -264,8 +265,9 @@ class TokenizerManager: return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num - session_id = obj.session[0] if obj.session else None - session_rid = obj.session[1] if obj.session else None + session_params = ( + SessionParams(**obj.session_params) if obj.session_params else None + ) if obj.input_ids is not None and len(input_ids) >= self.context_len: raise ValueError( @@ -292,8 +294,7 @@ class TokenizerManager: obj.stream, lora_path=obj.lora_path, input_embeds=input_embeds, - session_id=session_id, - session_rid=session_rid, + session_params=session_params, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( @@ -552,12 +553,16 @@ class TokenizerManager: ): self.auto_create_handle_loop() - session_id = uuid.uuid4().hex - obj.session_id = session_id + if obj.session_id is None: + obj.session_id = uuid.uuid4().hex + elif obj.session_id in self.session_futures: + return None + self.send_to_scheduler.send_pyobj(obj) - self.session_futures[session_id] = asyncio.Future() - session_id = await self.session_futures[session_id] - del self.session_futures[session_id] + + self.session_futures[obj.session_id] = asyncio.Future() + session_id = await self.session_futures[obj.session_id] + del self.session_futures[obj.session_id] return session_id async def close_session( @@ -709,7 +714,7 @@ class TokenizerManager: ) elif isinstance(recv_obj, OpenSessionReqOutput): self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id + recv_obj.session_id if recv_obj.success else None ) elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): if self.server_args.dp_size == 1: diff --git 
a/python/sglang/srt/server.py b/python/sglang/srt/server.py index a0d07ca44..ebf153e17 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request): """Open a session, and return its unique session id.""" try: session_id = await tokenizer_manager.open_session(obj, request) + if session_id is None: + raise Exception( + "Failed to open the session. Check if a session with the same id is still open." + ) return session_id except Exception as e: return _create_error_response(e) diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 47169aeaa..5653e9b69 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -1,11 +1,16 @@ """ Usage: python3 -m unittest test_session_control.TestSessionControl.test_session_control +python3 -m unittest test_session_control.TestSessionControl.test_session_control_with_branching +python3 -m unittest test_session_control.TestSessionControl.test_session_control_backtrack_with_abort python3 -m unittest test_session_control.TestSessionControlVision.test_session_control """ +import asyncio +import json import unittest +import aiohttp import requests from sglang.srt.hf_transformers_utils import get_tokenizer @@ -18,6 +23,10 @@ from sglang.test.test_utils import ( ) +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + class TestSessionControl(unittest.TestCase): @classmethod def setUpClass(cls): @@ -31,15 +40,18 @@ class TestSessionControl(unittest.TestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_session_control(self): + def test_session_control(self, gen_len=12): chunks = [ "Let me tell you something about France.", "The capital of France is", + "The population of the city is", "A brief history about that city is", - "To plan a travel, the budget is", ] tokenizer = 
get_tokenizer(self.model) chunks_ids = [tokenizer.encode(x) for x in chunks] + for i in range(1, len(chunks_ids)): + if chunks_ids[i][0] == tokenizer.bos_token_id: + chunks_ids[i] = chunks_ids[i][1:] # 1. using session control session_id = requests.post( @@ -48,6 +60,13 @@ class TestSessionControl(unittest.TestCase): ).json() rid = None + # open an existing session, should get session_id as None + response = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000, "session_id": session_id}, + ).json() + assert isinstance(response, dict) and "error" in response + first_rid = None outputs_from_session = [] for i, chunk_ids in enumerate(chunks_ids): @@ -55,11 +74,16 @@ class TestSessionControl(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": chunk_ids, - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 1 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -77,10 +101,15 @@ class TestSessionControl(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": chunks_ids[-1], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -93,10 +122,15 @@ class TestSessionControl(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": chunks_ids[-1], - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -115,10 +149,15 @@ class 
TestSessionControl(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": chunks_ids[-1], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -127,6 +166,8 @@ class TestSessionControl(unittest.TestCase): assert response["meta_info"]["finish_reason"]["type"] == "abort" # 2. not use session control + requests.post(self.base_url + "/flush_cache") + input_ids_first_req = None input_ids = [] outputs_normal = [] @@ -139,7 +180,7 @@ class TestSessionControl(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 1 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -150,7 +191,7 @@ class TestSessionControl(unittest.TestCase): output_ids = tokenizer.encode(response["text"]) if output_ids[0] == tokenizer.bos_token_id: output_ids = output_ids[1:] - input_ids += output_ids + input_ids += output_ids[:-1] outputs_normal.append(response["text"]) if i == 0: input_ids_first_req = input_ids.copy() @@ -162,7 +203,7 @@ class TestSessionControl(unittest.TestCase): "input_ids": input_ids_first_req, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -176,6 +217,272 @@ class TestSessionControl(unittest.TestCase): print(outputs_normal) assert outputs_from_session == outputs_normal + async def async_generate(self, payload): + url = self.base_url + "/generate" + async with aiohttp.ClientSession() as session: + async with session.post(url=url, json=payload) as response: + assert response.status == 200 + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk = 
remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + if chunk == "[DONE]": + yield "", None, "" + else: + data = json.loads(chunk) + finish_reason = ( + data["meta_info"]["finish_reason"]["type"] + if data["meta_info"]["finish_reason"] + else "" + ) + yield data["text"], data["meta_info"]["id"], finish_reason + + async def run_session_control_backtrack_with_abort(self, replace): + chunks = [ + "Let me tell you something about France.", + "The capital of France is", + ] + tokenizer = get_tokenizer(self.model) + chunks_ids = [tokenizer.encode(x) for x in chunks] + for i in range(1, len(chunks_ids)): + if chunks_ids[i][0] == tokenizer.bos_token_id: + chunks_ids[i] = chunks_ids[i][1:] + + # 1. using session control + session_id = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000}, + ).json() + rid = None + + payload = { + "input_ids": chunks_ids[0], + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 100, + "no_stop_trim": True, + "skip_special_tokens": False, + "ignore_eos": True, + }, + "stream": True, + } + gen_so_far = "" + finish_reason = "" + second_output = "" + async for chunk, rid, finish_reason_chunk in self.async_generate(payload): + gen_so_far += chunk + if finish_reason == "": + finish_reason = finish_reason_chunk + if len(gen_so_far) > 50 and second_output == "": + payload2 = { + "input_ids": chunks_ids[1], + "session_params": { + "id": session_id, + "rid": rid, + "offset": 50, + "replace": replace, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "stream": False, + "stream_output": True, + } + response = requests.post( + url=self.base_url + "/generate", json=payload2 + ).json() + second_output = response["text"] + if replace: + assert finish_reason == "abort" + print("first request output:") + print(gen_so_far) + print("second 
request output:") + print(second_output) + + # close the session + ret = requests.post( + self.base_url + "/close_session", + json={"session_id": session_id}, + ) + assert ret.status_code == 200 + + if not replace: + assert response["meta_info"]["finish_reason"]["type"] == "abort" + else: + # 2. not using session control + output_ids = tokenizer.encode(gen_so_far) + if output_ids[0] == tokenizer.bos_token_id: + output_ids = output_ids[1:] + input_ids = chunks_ids[0] + output_ids + input_ids = input_ids[:50] + chunks_ids[1] + payload = { + "input_ids": input_ids, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "stream": False, + "stream_output": True, + } + response = requests.post( + url=self.base_url + "/generate", json=payload + ).json() + output_no_session = response["text"] + print("second request output without session:") + print(output_no_session) + assert second_output == output_no_session + + def test_session_control_backtrack_with_abort(self): + asyncio.run(self.run_session_control_backtrack_with_abort(replace=True)) + asyncio.run(self.run_session_control_backtrack_with_abort(replace=False)) + + def run_session_control_with_branching( + self, root_prompt, chunks_per_step, gen_len=16 + ): + for x in chunks_per_step: + assert len(x) == len(chunks_per_step[0]) + + # 1. 
using session control + session_id = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000}, + ).json() + + outputs_from_session = [] + # send the root prompt + response = requests.post( + self.base_url + "/generate", + json={ + "text": root_prompt, + "session_params": { + "id": session_id, + "rid": None, + "offset": 0, + "replace": False, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + rid_per_branch = [response["meta_info"]["id"]] * len(chunks_per_step[0]) + outputs_from_session.append(response["text"]) + + # send the prompts in branches + for chunks_for_branches in chunks_per_step: + for j, chunk in enumerate(chunks_for_branches): + response = requests.post( + self.base_url + "/generate", + json={ + "text": chunk, + "session_params": { + "id": session_id, + "rid": rid_per_branch[j], + "offset": 0, + "replace": False, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + rid = response["meta_info"]["id"] + rid_per_branch[j] = rid + outputs_from_session.append(response["text"]) + + # close the session + ret = requests.post( + self.base_url + "/close_session", + json={"session_id": session_id}, + ) + assert ret.status_code == 200 + + # 2. 
not use session control + requests.post(self.base_url + "/flush_cache") + + outputs_normal = [] + input_texts = [root_prompt] * len(chunks_per_step[0]) + # send the root prompt + response = requests.post( + self.base_url + "/generate", + json={ + "text": root_prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + outputs_normal.append(response["text"]) + input_texts = [x + response["text"] for x in input_texts] + + # send the prompts in branches + for chunks_for_branches in chunks_per_step: + for j, chunk in enumerate(chunks_for_branches): + input_texts[j] += chunk + response = requests.post( + self.base_url + "/generate", + json={ + "text": input_texts[j], + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + outputs_normal.append(response["text"]) + input_texts[j] += response["text"] + + print("====== outputs from chunked queries with session control: =======") + print(outputs_from_session) + print("====== outputs from normal queries: =======") + print(outputs_normal) + assert outputs_from_session == outputs_normal + + def test_session_control_with_branching(self): + root_prompt = "First, let me explain in one sentence about AI" + chunks_per_step = [ + [ + "Then, briefly, the positive side of AI is", + "But, briefly, AI could be harmful to human", + ], + ["For example", "For example"], + ] + self.run_session_control_with_branching( + root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8 + ) + + root_prompt = "I have three apples." 
+ chunks_per_step = [ + ["I then give one apple to my friend", "My friend give me another apple."], + ["I still have", "I now have"], + ] + self.run_session_control_with_branching( + root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8 + ) + class TestSessionControlVision(unittest.TestCase): @classmethod @@ -197,17 +504,25 @@ class TestSessionControlVision(unittest.TestCase): text_chunks = [ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", "<|im_start|>user\n\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n", - "<|im_start|>user\n\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", - "<|im_start|>user\n\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\n\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\n\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\nDescribe this image in a very short sentence.<|im_end|>\nassistant:", ] image_chunks = [ + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png", - "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", - "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", ] - assert len(text_chunks) == len(image_chunks) + 1 + + assert ( + len(text_chunks) == len(image_chunks) + 2 + ) # the first and the last prompt does not contain images tokenizer = get_tokenizer(self.model) text_input_ids = [tokenizer.encode(x) for x in text_chunks] + for i in range(1, len(text_input_ids)): + if text_input_ids[i][0] == tokenizer.bos_token_id: + text_input_ids[i] = 
text_input_ids[i][1:] + gen_len = 32 # 1. using session control session_id = requests.post( @@ -216,20 +531,32 @@ class TestSessionControlVision(unittest.TestCase): ).json() rid = None + # open an existing session, should get session_id as None + response = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000, "session_id": session_id}, + ).json() + assert isinstance(response, dict) and "error" in response + first_rid = None outputs_from_session = [] - for i in range(len(text_input_ids)): + for i in range(len(text_input_ids[:-1])): response = requests.post( self.base_url + "/generate", json={ "input_ids": text_input_ids[i], "image_data": image_chunks[i - 1] if i > 0 else None, "modalities": ["multi-images"], - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 0 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -247,12 +574,15 @@ class TestSessionControlVision(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": text_input_ids[-1], - "image_data": image_chunks[-1:], - "modalities": ["multi-images"], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -265,12 +595,15 @@ class TestSessionControlVision(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": text_input_ids[-1], - "image_data": image_chunks[-1:], - "modalities": ["multi-images"], - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": 
gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -289,10 +622,15 @@ class TestSessionControlVision(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": text_input_ids[-1], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -306,7 +644,7 @@ class TestSessionControlVision(unittest.TestCase): input_ids_first_req = None input_ids = [] outputs_normal = [] - for i in range(len(text_input_ids)): + for i in range(len(text_input_ids[:-1])): input_ids += text_input_ids[i] image_data = image_chunks[:i] if i > 0 else None response = requests.post( @@ -318,7 +656,7 @@ class TestSessionControlVision(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 0 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -339,11 +677,9 @@ class TestSessionControlVision(unittest.TestCase): self.base_url + "/generate", json={ "input_ids": input_ids_first_req, - "image_data": image_chunks[-1:], - "modalities": ["multi-images"], "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, },