[feat] Refactor session control interface and add CI (#2173)
This commit is contained in:
@@ -173,7 +173,6 @@ class DetokenizerManager:
|
|||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
meta_info=recv_obj.meta_info,
|
meta_info=recv_obj.meta_info,
|
||||||
finished_reason=recv_obj.finished_reason,
|
finished_reason=recv_obj.finished_reason,
|
||||||
session_ids=recv_obj.session_ids,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -55,8 +55,9 @@ class GenerateReqInput:
|
|||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
# Session id info for continual prompting
|
# Session id info for continual prompting
|
||||||
session_id: Optional[Union[List[str], str]] = None
|
session: Optional[
|
||||||
session_rid: Optional[Union[List[str], str]] = None
|
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
|
||||||
|
] = None
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
if (self.text is None and self.input_ids is None) or (
|
||||||
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
|
|||||||
lora_path: Optional[str] = None # None means just use the base model
|
lora_path: Optional[str] = None # None means just use the base model
|
||||||
|
|
||||||
# Session id info for continual prompting
|
# Session id info for continual prompting
|
||||||
session_id: Optional[int] = None
|
session_id: Optional[str] = None
|
||||||
session_rid: Optional[str] = None
|
session_rid: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -299,8 +300,6 @@ class BatchTokenIDOut:
|
|||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
no_stop_trim: List[bool]
|
no_stop_trim: List[bool]
|
||||||
# The updated session unique id
|
|
||||||
session_ids: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -313,8 +312,6 @@ class BatchStrOut:
|
|||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
# The update session unique id
|
|
||||||
session_ids: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -542,9 +542,7 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
# Handle sessions
|
# Handle sessions
|
||||||
session = self.sessions[recv_req.session_id]
|
session = self.sessions[recv_req.session_id]
|
||||||
req, new_session_id = session.create_req(recv_req, self.tokenizer)
|
req = session.create_req(recv_req, self.tokenizer)
|
||||||
del self.sessions[recv_req.session_id]
|
|
||||||
self.sessions[new_session_id] = session
|
|
||||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
return
|
return
|
||||||
@@ -1188,7 +1186,6 @@ class Scheduler:
|
|||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
output_spaces_between_special_tokens = []
|
output_spaces_between_special_tokens = []
|
||||||
output_no_stop_trim = []
|
output_no_stop_trim = []
|
||||||
output_session_ids = []
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
output_embeddings = []
|
output_embeddings = []
|
||||||
|
|
||||||
@@ -1216,7 +1213,6 @@ class Scheduler:
|
|||||||
req.sampling_params.spaces_between_special_tokens
|
req.sampling_params.spaces_between_special_tokens
|
||||||
)
|
)
|
||||||
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||||
output_session_ids.append(req.session_id)
|
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": len(req.origin_input_ids),
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
@@ -1267,7 +1263,6 @@ class Scheduler:
|
|||||||
output_meta_info,
|
output_meta_info,
|
||||||
output_finished_reason,
|
output_finished_reason,
|
||||||
output_no_stop_trim,
|
output_no_stop_trim,
|
||||||
output_session_ids,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ class Session:
|
|||||||
self.reqs: List[Req] = []
|
self.reqs: List[Req] = []
|
||||||
|
|
||||||
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
||||||
# renew session id
|
|
||||||
self.session_id = uuid.uuid4().hex
|
|
||||||
if req.session_rid is not None:
|
if req.session_rid is not None:
|
||||||
while len(self.reqs) > 0:
|
while len(self.reqs) > 0:
|
||||||
if self.reqs[-1].rid == req.session_rid:
|
if self.reqs[-1].rid == req.session_rid:
|
||||||
break
|
break
|
||||||
self.reqs = self.reqs[:-1]
|
self.reqs = self.reqs[:-1]
|
||||||
|
else:
|
||||||
|
self.reqs = []
|
||||||
if len(self.reqs) > 0:
|
if len(self.reqs) > 0:
|
||||||
input_ids = (
|
input_ids = (
|
||||||
self.reqs[-1].origin_input_ids
|
self.reqs[-1].origin_input_ids
|
||||||
@@ -58,4 +58,4 @@ class Session:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.reqs.append(new_req)
|
self.reqs.append(new_req)
|
||||||
return new_req, self.session_id
|
return new_req
|
||||||
|
|||||||
@@ -216,8 +216,8 @@ class TokenizerManager:
|
|||||||
return_logprob = obj.return_logprob
|
return_logprob = obj.return_logprob
|
||||||
logprob_start_len = obj.logprob_start_len
|
logprob_start_len = obj.logprob_start_len
|
||||||
top_logprobs_num = obj.top_logprobs_num
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
session_id = obj.session_id
|
session_id = obj.session[0] if obj.session else None
|
||||||
session_rid = obj.session_rid
|
session_rid = obj.session[1] if obj.session else None
|
||||||
|
|
||||||
if len(input_ids) >= self.context_len:
|
if len(input_ids) >= self.context_len:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -570,13 +570,11 @@ class TokenizerManager:
|
|||||||
out_dict = {
|
out_dict = {
|
||||||
"text": recv_obj.output_strs[i],
|
"text": recv_obj.output_strs[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
"session_id": recv_obj.session_ids[i],
|
|
||||||
}
|
}
|
||||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"token_ids": recv_obj.output_ids[i],
|
"token_ids": recv_obj.output_ids[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
"session_id": recv_obj.session_ids[i],
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
|
|||||||
@@ -1,132 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
# FIXME: Make it a CI test
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
||||||
|
|
||||||
url = "http://localhost:30000"
|
|
||||||
|
|
||||||
# Open a session
|
|
||||||
response = requests.post(
|
|
||||||
url + "/open_session",
|
|
||||||
json={"capacity_of_str_len": 1000},
|
|
||||||
)
|
|
||||||
session_id = response.json()
|
|
||||||
print("session_id", session_id, "\n")
|
|
||||||
|
|
||||||
# Prefill only
|
|
||||||
prompt = "chunk 1"
|
|
||||||
response = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": prompt,
|
|
||||||
"session_id": session_id,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json(), "\n")
|
|
||||||
session_id = response.json()["session_id"]
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
prompt = "Chunk 2"
|
|
||||||
response = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": prompt,
|
|
||||||
"session_id": session_id,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 16,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json(), "\n")
|
|
||||||
session_id = response.json()["session_id"]
|
|
||||||
rid = response.json()["meta_info"]["id"]
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
prompt = "Chunk 3"
|
|
||||||
response = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": prompt,
|
|
||||||
"session_id": session_id,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json(), "\n")
|
|
||||||
session_id = response.json()["session_id"]
|
|
||||||
rid_to_del = response.json()["meta_info"]["id"]
|
|
||||||
|
|
||||||
# Interrupt and re-generate
|
|
||||||
prompt = "Chunk 4"
|
|
||||||
response = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": prompt,
|
|
||||||
"session_id": session_id,
|
|
||||||
"session_rid": rid,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 16,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json(), "\n")
|
|
||||||
session_id = response.json()["session_id"]
|
|
||||||
|
|
||||||
# Query a session based on a deleted request, should see finish reason abort
|
|
||||||
prompt = "Chunk 4"
|
|
||||||
response = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": prompt,
|
|
||||||
"session_id": session_id,
|
|
||||||
"session_rid": rid_to_del,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 16,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json(), "\n")
|
|
||||||
|
|
||||||
# Close session
|
|
||||||
ret = requests.post(
|
|
||||||
url + "/close_session",
|
|
||||||
json={"session_id": session_id},
|
|
||||||
)
|
|
||||||
print(ret, "\n")
|
|
||||||
|
|
||||||
# Query a deleted session, should see finish reason abort
|
|
||||||
prompt = "chunk 1"
|
|
||||||
response = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": prompt,
|
|
||||||
"session_id": session_id,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
print(response.json(), "\n")
|
|
||||||
@@ -34,6 +34,7 @@ suites = {
|
|||||||
"test_triton_attention_backend.py",
|
"test_triton_attention_backend.py",
|
||||||
"test_update_weights.py",
|
"test_update_weights.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
|
"test_session_control.py",
|
||||||
],
|
],
|
||||||
"sampling/penaltylib": glob.glob(
|
"sampling/penaltylib": glob.glob(
|
||||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||||
|
|||||||
168
test/srt/test_session_control.py
Normal file
168
test/srt/test_session_control.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_session_control.TestSessionControl.test_session_control
|
||||||
|
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionControl(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
|
def test_session_control(self):
|
||||||
|
chunks = [
|
||||||
|
"Let me tell you something about France.",
|
||||||
|
"The capital of France is",
|
||||||
|
"A brief history about that city is",
|
||||||
|
"To plan a travel, the budget is",
|
||||||
|
]
|
||||||
|
tokenizer = get_tokenizer(self.model)
|
||||||
|
chunks_ids = [tokenizer.encode(x) for x in chunks]
|
||||||
|
|
||||||
|
# 1. using session control
|
||||||
|
session_id = requests.post(
|
||||||
|
self.base_url + "/open_session",
|
||||||
|
json={"capacity_of_str_len": 1000},
|
||||||
|
).json()
|
||||||
|
rid = None
|
||||||
|
|
||||||
|
first_rid = None
|
||||||
|
outputs_from_session = []
|
||||||
|
for i, chunk_ids in enumerate(chunks_ids):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": chunk_ids,
|
||||||
|
"session": [session_id, rid],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": (
|
||||||
|
16 if i > 0 else 0
|
||||||
|
), # prefill only for the first chunk
|
||||||
|
},
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
rid = response["meta_info"]["id"]
|
||||||
|
if i == 0:
|
||||||
|
first_rid = rid
|
||||||
|
if i > 0:
|
||||||
|
outputs_from_session.append(response["text"])
|
||||||
|
|
||||||
|
# backtrack to the first request and regenerate
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": chunks_ids[-1],
|
||||||
|
"session": [session_id, first_rid],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
outputs_from_session.append(response["text"])
|
||||||
|
|
||||||
|
# query with a non-existing rid (the last one should have disappeared because of the backtrack), should see abort
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": chunks_ids[-1],
|
||||||
|
"session": [session_id, rid],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||||
|
|
||||||
|
ret = requests.post(
|
||||||
|
self.base_url + "/close_session",
|
||||||
|
json={"session_id": session_id},
|
||||||
|
)
|
||||||
|
assert ret.status_code == 200
|
||||||
|
|
||||||
|
# send a request to a closed session, should see abort
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": chunks_ids[-1],
|
||||||
|
"session": [session_id, first_rid],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||||
|
|
||||||
|
# 2. without session control
|
||||||
|
input_ids_first_req = None
|
||||||
|
input_ids = []
|
||||||
|
outputs_normal = []
|
||||||
|
for i, chunk_ids in enumerate(chunks_ids):
|
||||||
|
input_ids += chunk_ids
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": (
|
||||||
|
16 if i > 0 else 0
|
||||||
|
), # prefill only for the first chunk
|
||||||
|
},
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
if i > 0:
|
||||||
|
input_ids += tokenizer.encode(response["text"])[
|
||||||
|
1:
|
||||||
|
] # drop the bos token
|
||||||
|
outputs_normal.append(response["text"])
|
||||||
|
if i == 0:
|
||||||
|
input_ids_first_req = input_ids.copy()
|
||||||
|
|
||||||
|
input_ids_first_req += chunks_ids[-1]
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": input_ids_first_req,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
outputs_normal.append(response["text"])
|
||||||
|
|
||||||
|
print("outputs from chunked queries with session control:")
|
||||||
|
print(outputs_from_session)
|
||||||
|
print("outputs from normal queries:")
|
||||||
|
print(outputs_normal)
|
||||||
|
assert outputs_from_session == outputs_normal
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user