[feat] Refactor session control interface and add CI (#2173)
This commit is contained in:
@@ -34,6 +34,7 @@ suites = {
|
||||
"test_triton_attention_backend.py",
|
||||
"test_update_weights.py",
|
||||
"test_vision_openai_server.py",
|
||||
"test_session_control.py",
|
||||
],
|
||||
"sampling/penaltylib": glob.glob(
|
||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||
|
||||
168
test/srt/test_session_control.py
Normal file
168
test/srt/test_session_control.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestSessionControl(unittest.TestCase):
    """End-to-end test of the server's session control HTTP interface.

    A server process is launched once for the whole class. The test sends
    chunked prompts through a session and checks that the generated texts
    match those obtained by sending the equivalent concatenated prompts
    without a session.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by every test in this class."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the launched server process and its children."""
        kill_child_process(cls.process.pid, include_self=True)

    def _generate(self, input_ids, session=None, max_new_tokens=16):
        """POST a greedy (temperature 0) request to ``/generate``.

        Args:
            input_ids: Token ids to send as the prompt.
            session: Optional ``[session_id, rid]`` pair. When ``None`` the
                ``"session"`` key is omitted from the payload entirely.
            max_new_tokens: Value for ``sampling_params.max_new_tokens``.

        Returns:
            The decoded JSON body of the response.
        """
        payload = {"input_ids": input_ids}
        if session is not None:
            payload["session"] = session
        payload["sampling_params"] = {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        }
        return requests.post(self.base_url + "/generate", json=payload).json()

    def test_session_control(self):
        """Compare session-based chunked generation with plain generation."""
        chunks = [
            "Let me tell you something about France.",
            "The capital of France is",
            "A brief history about that city is",
            "To plan a travel, the budget is",
        ]
        tokenizer = get_tokenizer(self.model)
        chunks_ids = [tokenizer.encode(x) for x in chunks]

        # 1. using session control
        session_id = requests.post(
            self.base_url + "/open_session",
            json={"capacity_of_str_len": 1000},
        ).json()
        rid = None

        first_rid = None
        outputs_from_session = []
        for i, chunk_ids in enumerate(chunks_ids):
            response = self._generate(
                chunk_ids,
                session=[session_id, rid],
                # prefill only for the first chunk
                max_new_tokens=16 if i > 0 else 0,
            )
            # Each request continues from the request id returned by the
            # previous one.
            rid = response["meta_info"]["id"]
            if i == 0:
                first_rid = rid
            else:
                outputs_from_session.append(response["text"])

        # backtrack to the first request and regenerate
        response = self._generate(chunks_ids[-1], session=[session_id, first_rid])
        outputs_from_session.append(response["text"])

        # query with a non-existing rid (the last one should have disappeared
        # because of the backtrack), should see abort
        response = self._generate(chunks_ids[-1], session=[session_id, rid])
        self.assertEqual(response["meta_info"]["finish_reason"]["type"], "abort")

        ret = requests.post(
            self.base_url + "/close_session",
            json={"session_id": session_id},
        )
        self.assertEqual(ret.status_code, 200)

        # send a request to a closed session, should see abort
        response = self._generate(chunks_ids[-1], session=[session_id, first_rid])
        self.assertEqual(response["meta_info"]["finish_reason"]["type"], "abort")

        # 2. not use session control
        input_ids_first_req = None
        input_ids = []
        outputs_normal = []
        for i, chunk_ids in enumerate(chunks_ids):
            input_ids += chunk_ids
            response = self._generate(
                input_ids,
                # prefill only for the first chunk
                max_new_tokens=16 if i > 0 else 0,
            )
            if i > 0:
                # Feed the generated text back into the running prompt.
                input_ids += tokenizer.encode(response["text"])[
                    1:
                ]  # drop the bos token
                outputs_normal.append(response["text"])
            if i == 0:
                # Snapshot of the prompt after the first chunk, used below to
                # mirror the session backtrack.
                input_ids_first_req = input_ids.copy()

        input_ids_first_req += chunks_ids[-1]
        response = self._generate(input_ids_first_req)
        outputs_normal.append(response["text"])

        print("outputs from chunked queries with session control:")
        print(outputs_from_session)
        print("outputs from normal queries:")
        print(outputs_normal)
        self.assertEqual(outputs_from_session, outputs_normal)
|
||||
|
||||
|
||||
# Run the unittest test runner when this module is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user