diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 3c58a6508..73054bf8f 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -173,7 +173,6 @@ class DetokenizerManager: output_strs=output_strs, meta_info=recv_obj.meta_info, finished_reason=recv_obj.finished_reason, - session_ids=recv_obj.session_ids, ) ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index d311f6e2a..9541b2d18 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller). import uuid from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -55,8 +55,9 @@ class GenerateReqInput: lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None # Session id info for continual prompting - session_id: Optional[Union[List[str], str]] = None - session_rid: Optional[Union[List[str], str]] = None + session: Optional[ + Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] + ] = None def normalize_batch_and_arguments(self): if (self.text is None and self.input_ids is None) or ( @@ -203,7 +204,7 @@ class TokenizedGenerateReqInput: lora_path: Optional[str] = None # None means just use the base model # Session id info for continual prompting - session_id: Optional[int] = None + session_id: Optional[str] = None session_rid: Optional[str] = None @@ -299,8 +300,6 @@ class BatchTokenIDOut: meta_info: List[Dict] finished_reason: List[BaseFinishReason] no_stop_trim: List[bool] - # The updated session unique id - session_ids: List[str] @dataclass @@ -313,8 +312,6 @@ class 
BatchStrOut: meta_info: List[Dict] # The finish reason finished_reason: List[BaseFinishReason] - # The update session unique id - session_ids: List[str] @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9ca9b7c64..1d1cf3688 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -542,9 +542,7 @@ class Scheduler: else: # Handle sessions session = self.sessions[recv_req.session_id] - req, new_session_id = session.create_req(recv_req, self.tokenizer) - del self.sessions[recv_req.session_id] - self.sessions[new_session_id] = session + req = session.create_req(recv_req, self.tokenizer) if isinstance(req.finished_reason, FINISH_ABORT): self.waiting_queue.append(req) return @@ -1188,7 +1186,6 @@ class Scheduler: output_skip_special_tokens = [] output_spaces_between_special_tokens = [] output_no_stop_trim = [] - output_session_ids = [] else: # embedding or reward model output_embeddings = [] @@ -1216,7 +1213,6 @@ class Scheduler: req.sampling_params.spaces_between_special_tokens ) output_no_stop_trim.append(req.sampling_params.no_stop_trim) - output_session_ids.append(req.session_id) meta_info = { "prompt_tokens": len(req.origin_input_ids), @@ -1267,7 +1263,6 @@ class Scheduler: output_meta_info, output_finished_reason, output_no_stop_trim, - output_session_ids, ) ) else: # embedding or reward model diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index c5c51aade..6a5b5b5b5 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -26,13 +26,13 @@ class Session: self.reqs: List[Req] = [] def create_req(self, req: TokenizedGenerateReqInput, tokenizer): - # renew session id - self.session_id = uuid.uuid4().hex if req.session_rid is not None: while len(self.reqs) > 0: if self.reqs[-1].rid == req.session_rid: break self.reqs = self.reqs[:-1] + 
else: + self.reqs = [] if len(self.reqs) > 0: input_ids = ( self.reqs[-1].origin_input_ids @@ -58,4 +58,4 @@ class Session: ) else: self.reqs.append(new_req) - return new_req, self.session_id + return new_req diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index cb0f8738e..be58c939f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -216,8 +216,8 @@ class TokenizerManager: return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num - session_id = obj.session_id - session_rid = obj.session_rid + session_id = obj.session[0] if obj.session else None + session_rid = obj.session[1] if obj.session else None if len(input_ids) >= self.context_len: raise ValueError( @@ -570,13 +570,11 @@ class TokenizerManager: out_dict = { "text": recv_obj.output_strs[i], "meta_info": recv_obj.meta_info[i], - "session_id": recv_obj.session_ids[i], } elif isinstance(recv_obj, BatchTokenIDOut): out_dict = { "token_ids": recv_obj.output_ids[i], "meta_info": recv_obj.meta_info[i], - "session_id": recv_obj.session_ids[i], } else: assert isinstance(recv_obj, BatchEmbeddingOut) diff --git a/scripts/playground/test_session_id.py b/scripts/playground/test_session_id.py deleted file mode 100644 index d29f7d885..000000000 --- a/scripts/playground/test_session_id.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# FIXME: Make it a CI test - -import requests - -from sglang.srt.hf_transformers_utils import get_tokenizer - -url = "http://localhost:30000" - -# Open a session -response = requests.post( - url + "/open_session", - json={"capacity_of_str_len": 1000}, -) -session_id = response.json() -print("session_id", session_id, "\n") - -# Prefill only -prompt = "chunk 1" -response = requests.post( - url + "/generate", - json={ - "text": prompt, - "session_id": session_id, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 0, - }, - }, -) -print(response.json(), "\n") -session_id = response.json()["session_id"] - -# Generate -prompt = "Chunk 2" -response = requests.post( - url + "/generate", - json={ - "text": prompt, - "session_id": session_id, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 16, - }, - }, -) -print(response.json(), "\n") -session_id = response.json()["session_id"] -rid = response.json()["meta_info"]["id"] - -# Generate -prompt = "Chunk 3" -response = requests.post( - url + "/generate", - json={ - "text": prompt, - "session_id": session_id, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 2, - }, - }, -) -print(response.json(), "\n") -session_id = response.json()["session_id"] -rid_to_del = response.json()["meta_info"]["id"] - -# Interrupt and re-generate -prompt = "Chunk 4" -response = requests.post( - url + "/generate", - json={ - "text": prompt, - "session_id": session_id, - "session_rid": rid, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 16, - }, - }, -) -print(response.json(), "\n") -session_id = response.json()["session_id"] - -# Query a session based on a deleted request, should see finish reason abort -prompt = "Chunk 4" -response = requests.post( - url + "/generate", - json={ - "text": prompt, - 
"session_id": session_id, - "session_rid": rid_to_del, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 16, - }, - }, -) -print(response.json(), "\n") - -# Close session -ret = requests.post( - url + "/close_session", - json={"session_id": session_id}, -) -print(ret, "\n") - -# Query a deleted session, should see finish reason abort -prompt = "chunk 1" -response = requests.post( - url + "/generate", - json={ - "text": prompt, - "session_id": session_id, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 0, - }, - }, -) -print(response.json(), "\n") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 9bd7b9810..27fe6d7d3 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -34,6 +34,7 @@ suites = { "test_triton_attention_backend.py", "test_update_weights.py", "test_vision_openai_server.py", + "test_session_control.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py new file mode 100644 index 000000000..e5b5e7b6c --- /dev/null +++ b/test/srt/test_session_control.py @@ -0,0 +1,168 @@ +""" +Usage: +python3 -m unittest test_session_control.TestSessionControl.test_session_control +python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm +""" + +import unittest + +import requests + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestSessionControl(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + + @classmethod + def tearDownClass(cls): + 
kill_child_process(cls.process.pid, include_self=True) + + def test_session_control(self): + chunks = [ + "Let me tell you something about France.", + "The capital of France is", + "A brief history about that city is", + "To plan a travel, the budget is", + ] + tokenizer = get_tokenizer(self.model) + chunks_ids = [tokenizer.encode(x) for x in chunks] + + # 1. using session control + session_id = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000}, + ).json() + rid = None + + first_rid = None + outputs_from_session = [] + for i, chunk_ids in enumerate(chunks_ids): + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": chunk_ids, + "session": [session_id, rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": ( + 16 if i > 0 else 0 + ), # prefill only for the first chunk + }, + }, + ).json() + rid = response["meta_info"]["id"] + if i == 0: + first_rid = rid + if i > 0: + outputs_from_session.append(response["text"]) + + # backtrack to the first request and regenerate + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": chunks_ids[-1], + "session": [session_id, first_rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, + ).json() + outputs_from_session.append(response["text"]) + + # query with a non-existing rid (the last one should have disappeared because of backtrack), should see abort + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": chunks_ids[-1], + "session": [session_id, rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, + ).json() + assert response["meta_info"]["finish_reason"]["type"] == "abort" + + ret = requests.post( + self.base_url + "/close_session", + json={"session_id": session_id}, + ) + assert ret.status_code == 200 + + # send a request to a closed session, should see abort + response = requests.post( + self.base_url + "/generate", + json={
"input_ids": chunks_ids[-1], + "session": [session_id, first_rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, + ).json() + assert response["meta_info"]["finish_reason"]["type"] == "abort" + + # 2. not use session control + input_ids_first_req = None + input_ids = [] + outputs_normal = [] + for i, chunk_ids in enumerate(chunks_ids): + input_ids += chunk_ids + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": 0, + "max_new_tokens": ( + 16 if i > 0 else 0 + ), # prefill only for the first chunk + }, + }, + ).json() + if i > 0: + input_ids += tokenizer.encode(response["text"])[ + 1: + ] # drop the bos token + outputs_normal.append(response["text"]) + if i == 0: + input_ids_first_req = input_ids.copy() + + input_ids_first_req += chunks_ids[-1] + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids_first_req, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, + ).json() + outputs_normal.append(response["text"]) + + print("outputs from chunked queries with session control:") + print(outputs_from_session) + print("outputs from normal queries:") + print(outputs_normal) + assert outputs_from_session == outputs_normal + + +if __name__ == "__main__": + unittest.main()