[feat] Refactor session control interface and add CI (#2173)
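Summary: the paired `session_id` / `session_rid` fields on `GenerateReqInput` are folded into a single `session` tuple of `(session_id, rid)`; session ids are no longer renewed on every request, so batch outputs stop carrying `session_ids`; `TokenizedGenerateReqInput.session_id` is fixed from `int` to `str`; and the standalone session-control example script is replaced by a CI test (`test/srt/test_session_control.py`).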
@@ -173,7 +173,6 @@ class DetokenizerManager:
                 output_strs=output_strs,
                 meta_info=recv_obj.meta_info,
                 finished_reason=recv_obj.finished_reason,
-                session_ids=recv_obj.session_ids,
             )
         )

@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams

@@ -55,8 +55,9 @@ class GenerateReqInput:
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

     # Session id info for continual prompting
-    session_id: Optional[Union[List[str], str]] = None
-    session_rid: Optional[Union[List[str], str]] = None
+    session: Optional[
+        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
+    ] = None

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (

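For illustration, a minimal client-side sketch of the new interface, assuming a local server at http://localhost:30000 as in the example script further below; the prompts and sampling values are illustrative:

import requests

base_url = "http://localhost:30000"  # assumed local server address

# Open a session; the server replies with the session id.
session_id = requests.post(
    base_url + "/open_session",
    json={"capacity_of_str_len": 1000},
).json()

# Old interface: separate "session_id" and "session_rid" fields.
# New interface: one "session" pair of (session_id, rid); rid picks the
# earlier request to continue from, and None starts from scratch.
response = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital of France is",
        "session": [session_id, None],
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()

# Chain a follow-up request by passing the previous request id back.
rid = response["meta_info"]["id"]
response = requests.post(
    base_url + "/generate",
    json={
        "text": "A brief history of that city is",
        "session": [session_id, rid],
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()

The CI test added at the bottom of this diff exercises exactly this flow, plus the backtracking and abort cases.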
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
     lora_path: Optional[str] = None  # None means just use the base model

     # Session id info for continual prompting
-    session_id: Optional[int] = None
+    session_id: Optional[str] = None
     session_rid: Optional[str] = None

@@ -299,8 +300,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
-    # The updated session unique id
-    session_ids: List[str]


 @dataclass

@@ -313,8 +312,6 @@ class BatchStrOut:
     meta_info: List[Dict]
     # The finish reason
     finished_reason: List[BaseFinishReason]
-    # The update session unique id
-    session_ids: List[str]


 @dataclass

@@ -542,9 +542,7 @@ class Scheduler:
         else:
             # Handle sessions
             session = self.sessions[recv_req.session_id]
-            req, new_session_id = session.create_req(recv_req, self.tokenizer)
-            del self.sessions[recv_req.session_id]
-            self.sessions[new_session_id] = session
+            req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
                 return

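Since `Session.create_req` no longer renews the session id (see the `Session` hunk below), the `self.sessions` entry stays valid under its original key, so the delete-and-reinsert of the old code is unnecessary.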
@@ -1188,7 +1186,6 @@ class Scheduler:
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-            output_session_ids = []
         else:  # embedding or reward model
             output_embeddings = []

@@ -1216,7 +1213,6 @@ class Scheduler:
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-                output_session_ids.append(req.session_id)

                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),

@@ -1267,7 +1263,6 @@ class Scheduler:
                         output_meta_info,
                         output_finished_reason,
                         output_no_stop_trim,
-                        output_session_ids,
                     )
                 )
         else:  # embedding or reward model

@@ -26,13 +26,13 @@ class Session:
         self.reqs: List[Req] = []

     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        # renew session id
-        self.session_id = uuid.uuid4().hex
         if req.session_rid is not None:
             while len(self.reqs) > 0:
                 if self.reqs[-1].rid == req.session_rid:
                     break
                 self.reqs = self.reqs[:-1]
+        else:
+            self.reqs = []
         if len(self.reqs) > 0:
             input_ids = (
                 self.reqs[-1].origin_input_ids

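The loop above is the backtracking mechanism: requests newer than `session_rid` are popped off the tail until the matching request is last, and a `None` rid clears the history entirely. A standalone sketch of that trim, with made-up rids:

# Standalone sketch of the backtracking trim, using plain strings as rids.
reqs = ["r1", "r2", "r3", "r4"]  # oldest -> newest
session_rid = "r2"               # backtrack target

if session_rid is not None:
    while len(reqs) > 0:
        if reqs[-1] == session_rid:
            break
        reqs = reqs[:-1]
else:
    reqs = []

print(reqs)  # ['r1', 'r2'] -- 'r3' and 'r4' are discarded

If no request matches the rid, the list drains to empty and the request is answered with an abort finish reason, which the new CI test asserts.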
@@ -58,4 +58,4 @@ class Session:
             )
         else:
             self.reqs.append(new_req)
-        return new_req, self.session_id
+        return new_req

@@ -216,8 +216,8 @@ class TokenizerManager:
         return_logprob = obj.return_logprob
         logprob_start_len = obj.logprob_start_len
         top_logprobs_num = obj.top_logprobs_num
-        session_id = obj.session_id
-        session_rid = obj.session_rid
+        session_id = obj.session[0] if obj.session else None
+        session_rid = obj.session[1] if obj.session else None

         if len(input_ids) >= self.context_len:
             raise ValueError(

@@ -570,13 +570,11 @@ class TokenizerManager:
             out_dict = {
                 "text": recv_obj.output_strs[i],
                 "meta_info": recv_obj.meta_info[i],
-                "session_id": recv_obj.session_ids[i],
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
             out_dict = {
                 "token_ids": recv_obj.output_ids[i],
                 "meta_info": recv_obj.meta_info[i],
-                "session_id": recv_obj.session_ids[i],
             }
         else:
             assert isinstance(recv_obj, BatchEmbeddingOut)

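With stable session ids there is nothing to echo back, so responses drop the `session_id` field; clients keep the id returned by `/open_session` and use `meta_info["id"]` as the rid for later backtracking, as the new test below does.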
@@ -1,132 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# FIXME: Make it a CI test
-
-import requests
-
-from sglang.srt.hf_transformers_utils import get_tokenizer
-
-url = "http://localhost:30000"
-
-# Open a session
-response = requests.post(
-    url + "/open_session",
-    json={"capacity_of_str_len": 1000},
-)
-session_id = response.json()
-print("session_id", session_id, "\n")
-
-# Prefill only
-prompt = "chunk 1"
-response = requests.post(
-    url + "/generate",
-    json={
-        "text": prompt,
-        "session_id": session_id,
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": 0,
-        },
-    },
-)
-print(response.json(), "\n")
-session_id = response.json()["session_id"]
-
-# Generate
-prompt = "Chunk 2"
-response = requests.post(
-    url + "/generate",
-    json={
-        "text": prompt,
-        "session_id": session_id,
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": 16,
-        },
-    },
-)
-print(response.json(), "\n")
-session_id = response.json()["session_id"]
-rid = response.json()["meta_info"]["id"]
-
-# Generate
-prompt = "Chunk 3"
-response = requests.post(
-    url + "/generate",
-    json={
-        "text": prompt,
-        "session_id": session_id,
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": 2,
-        },
-    },
-)
-print(response.json(), "\n")
-session_id = response.json()["session_id"]
-rid_to_del = response.json()["meta_info"]["id"]
-
-# Interrupt and re-generate
-prompt = "Chunk 4"
-response = requests.post(
-    url + "/generate",
-    json={
-        "text": prompt,
-        "session_id": session_id,
-        "session_rid": rid,
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": 16,
-        },
-    },
-)
-print(response.json(), "\n")
-session_id = response.json()["session_id"]
-
-# Query a session based on a deleted request, should see finish reason abort
-prompt = "Chunk 4"
-response = requests.post(
-    url + "/generate",
-    json={
-        "text": prompt,
-        "session_id": session_id,
-        "session_rid": rid_to_del,
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": 16,
-        },
-    },
-)
-print(response.json(), "\n")
-
-# Close session
-ret = requests.post(
-    url + "/close_session",
-    json={"session_id": session_id},
-)
-print(ret, "\n")
-
-# Query a deleted session, should see finish reason abort
-prompt = "chunk 1"
-response = requests.post(
-    url + "/generate",
-    json={
-        "text": prompt,
-        "session_id": session_id,
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": 0,
-        },
-    },
-)
-print(response.json(), "\n")

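This deleted example script is the `FIXME: Make it a CI test` being resolved: its open/generate/backtrack/close flow reappears as the unit test `test/srt/test_session_control.py` added below and registered in the test suite.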
@@ -34,6 +34,7 @@ suites = {
         "test_triton_attention_backend.py",
         "test_update_weights.py",
         "test_vision_openai_server.py",
+        "test_session_control.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True

test/srt/test_session_control.py (new file, 168 lines)
@@ -0,0 +1,168 @@
+"""
+Usage:
+python3 -m unittest test_session_control.TestSessionControl.test_session_control
+python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
+"""
+
+import unittest
+
+import requests
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestSessionControl(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid, include_self=True)
+
+    def test_session_control(self):
+        chunks = [
+            "Let me tell you something about France.",
+            "The capital of France is",
+            "A brief history about that city is",
+            "To plan a travel, the budget is",
+        ]
+        tokenizer = get_tokenizer(self.model)
+        chunks_ids = [tokenizer.encode(x) for x in chunks]
+
+        # 1. using session control
+        session_id = requests.post(
+            self.base_url + "/open_session",
+            json={"capacity_of_str_len": 1000},
+        ).json()
+        rid = None
+
+        first_rid = None
+        outputs_from_session = []
+        for i, chunk_ids in enumerate(chunks_ids):
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    "input_ids": chunk_ids,
+                    "session": [session_id, rid],
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": (
+                            16 if i > 0 else 0
+                        ),  # prefill only for the first chunk
+                    },
+                },
+            ).json()
+            rid = response["meta_info"]["id"]
+            if i == 0:
+                first_rid = rid
+            if i > 0:
+                outputs_from_session.append(response["text"])
+
+        # backtrack to the first request and regenerate
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": chunks_ids[-1],
+                "session": [session_id, first_rid],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                },
+            },
+        ).json()
+        outputs_from_session.append(response["text"])
+
+        # query with a non-existing rid (the last one has disappeared because of the backtrack); should see abort
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": chunks_ids[-1],
+                "session": [session_id, rid],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                },
+            },
+        ).json()
+        assert response["meta_info"]["finish_reason"]["type"] == "abort"
+
+        ret = requests.post(
+            self.base_url + "/close_session",
+            json={"session_id": session_id},
+        )
+        assert ret.status_code == 200
+
+        # send a request to a closed session, should see abort
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": chunks_ids[-1],
+                "session": [session_id, first_rid],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                },
+            },
+        ).json()
+        assert response["meta_info"]["finish_reason"]["type"] == "abort"
+
+        # 2. without session control
+        input_ids_first_req = None
+        input_ids = []
+        outputs_normal = []
+        for i, chunk_ids in enumerate(chunks_ids):
+            input_ids += chunk_ids
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    "input_ids": input_ids,
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": (
+                            16 if i > 0 else 0
+                        ),  # prefill only for the first chunk
+                    },
+                },
+            ).json()
+            if i > 0:
+                input_ids += tokenizer.encode(response["text"])[
+                    1:
+                ]  # drop the bos token
+                outputs_normal.append(response["text"])
+            if i == 0:
+                input_ids_first_req = input_ids.copy()
+
+        input_ids_first_req += chunks_ids[-1]
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids_first_req,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                },
+            },
+        ).json()
+        outputs_normal.append(response["text"])
+
+        print("outputs from chunked queries with session control:")
+        print(outputs_from_session)
+        print("outputs from normal queries:")
+        print(outputs_normal)
+        assert outputs_from_session == outputs_normal
+
+
+if __name__ == "__main__":
+    unittest.main()
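The final equality assertion relies on greedy decoding (`temperature: 0`): with session control each chunk extends the server-side token history, so its outputs must match those of a client that manually concatenates prompt and generated ids and resends the whole sequence.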