Files
sglang/test/srt/test_session_control.py

169 lines
5.5 KiB
Python

"""
Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
"""
import unittest
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestSessionControl(unittest.TestCase):
    """End-to-end check of the server's session API against plain requests.

    A server process is launched once for the whole class. The test then
    verifies that feeding a prompt chunk-by-chunk through a session (including
    backtracking to an earlier request) yields the same generations as sending
    the equivalent concatenated token ids without a session.
    """

    # Upper bound, in seconds, for every HTTP call made by the test so that a
    # stalled server fails the test instead of hanging the run indefinitely.
    REQUEST_TIMEOUT = 300

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by all tests in this class."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in ``setUpClass``."""
        kill_child_process(cls.process.pid, include_self=True)

    def _generate(self, input_ids, max_new_tokens, session=None):
        """POST a temperature-0 request to ``/generate`` and return its JSON.

        Args:
            input_ids: Token ids to send as the request's ``input_ids``.
            max_new_tokens: Value for ``sampling_params.max_new_tokens``.
            session: Optional ``[session_id, rid]`` pair. When ``None`` the
                ``session`` field is omitted from the payload entirely.

        Returns:
            The decoded JSON body of the response.
        """
        payload = {"input_ids": input_ids}
        if session is not None:
            payload["session"] = session
        payload["sampling_params"] = {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        }
        return requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        ).json()

    def test_session_control(self):
        """Session-based chunked generation must match session-less generation."""
        chunks = [
            "Let me tell you something about France.",
            "The capital of France is",
            "A brief history about that city is",
            "To plan a travel, the budget is",
        ]
        tokenizer = get_tokenizer(self.model)
        chunks_ids = [tokenizer.encode(x) for x in chunks]

        # 1. using session control
        session_id = requests.post(
            self.base_url + "/open_session",
            json={"capacity_of_str_len": 1000},
            timeout=self.REQUEST_TIMEOUT,
        ).json()
        rid = None
        first_rid = None
        outputs_from_session = []
        for i, chunk_ids in enumerate(chunks_ids):
            response = self._generate(
                chunk_ids,
                # prefill only for the first chunk
                max_new_tokens=16 if i > 0 else 0,
                session=[session_id, rid],
            )
            # Each request is chained onto the id of the previous one.
            rid = response["meta_info"]["id"]
            if i == 0:
                first_rid = rid
            else:
                outputs_from_session.append(response["text"])

        # backtrack to the first request and regenerate
        response = self._generate(
            chunks_ids[-1], 16, session=[session_id, first_rid]
        )
        outputs_from_session.append(response["text"])

        # query with a non-existing rid (the last one should have disappeared
        # because of the backtrack), should see abort
        response = self._generate(chunks_ids[-1], 16, session=[session_id, rid])
        self.assertEqual(response["meta_info"]["finish_reason"]["type"], "abort")

        ret = requests.post(
            self.base_url + "/close_session",
            json={"session_id": session_id},
            timeout=self.REQUEST_TIMEOUT,
        )
        self.assertEqual(ret.status_code, 200)

        # send a request to a closed session, should see abort
        response = self._generate(
            chunks_ids[-1], 16, session=[session_id, first_rid]
        )
        self.assertEqual(response["meta_info"]["finish_reason"]["type"], "abort")

        # 2. not use session control
        input_ids_first_req = None
        input_ids = []
        outputs_normal = []
        for i, chunk_ids in enumerate(chunks_ids):
            # Grow the prompt manually: previous prompt + this chunk.
            input_ids += chunk_ids
            response = self._generate(
                input_ids,
                # prefill only for the first chunk
                max_new_tokens=16 if i > 0 else 0,
            )
            if i > 0:
                # Append the generated text to the running prompt as well.
                input_ids += tokenizer.encode(response["text"])[
                    1:
                ]  # drop the bos token
                outputs_normal.append(response["text"])
            if i == 0:
                # Snapshot of the prompt after the first request; used below to
                # mirror the session-side backtrack.
                input_ids_first_req = input_ids.copy()

        # Session-less counterpart of the backtrack: first chunk + last chunk.
        input_ids_first_req += chunks_ids[-1]
        response = self._generate(input_ids_first_req, 16)
        outputs_normal.append(response["text"])

        print("outputs from chunked queries with session control:")
        print(outputs_from_session)
        print("outputs from normal queries:")
        print(outputs_normal)
        self.assertEqual(outputs_from_session, outputs_normal)
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()