[feat] Add session control (#2073)
This commit is contained in:
@@ -175,6 +175,7 @@ class DetokenizerManager:
|
||||
output_strs=output_strs,
|
||||
meta_info=recv_obj.meta_info,
|
||||
finished_reason=recv_obj.finished_reason,
|
||||
session_ids=recv_obj.session_ids,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -56,6 +56,10 @@ class GenerateReqInput:
|
||||
# LoRA related
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
# Session id info for continual prompting
|
||||
session_id: Optional[Union[List[str], str]] = None
|
||||
session_rid: Optional[Union[List[str], str]] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
self.text is not None and self.input_ids is not None
|
||||
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
|
||||
# LoRA related
|
||||
lora_path: Optional[str] = None # None means just use the base model
|
||||
|
||||
# Session id info for continual prompting
|
||||
session_id: Optional[int] = None
|
||||
session_rid: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
|
||||
meta_info: List[Dict]
|
||||
finished_reason: List[BaseFinishReason]
|
||||
no_stop_trim: List[bool]
|
||||
# The updated session unique id
|
||||
session_ids: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -305,6 +315,8 @@ class BatchStrOut:
|
||||
meta_info: List[Dict]
|
||||
# The finish reason
|
||||
finished_reason: List[BaseFinishReason]
|
||||
# The update session unique id
|
||||
session_ids: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
|
||||
@dataclass
|
||||
class GetMemPoolSizeReqOutput:
|
||||
size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenSessionReqInput:
|
||||
capacity_of_str_len: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CloseSessionReqInput:
|
||||
session_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenSessionReqOutput:
|
||||
session_id: str
|
||||
|
||||
@@ -180,6 +180,7 @@ class Req:
|
||||
origin_input_ids: Tuple[int],
|
||||
sampling_params: SamplingParams,
|
||||
lora_path: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -188,6 +189,8 @@ class Req:
|
||||
self.origin_input_ids = origin_input_ids
|
||||
self.output_ids = [] # Each decode stage's output ids
|
||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||
self.session_id = session_id
|
||||
|
||||
self.sampling_params = sampling_params
|
||||
self.lora_path = lora_path
|
||||
|
||||
|
||||
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
FlushCacheReq,
|
||||
GetMemPoolSizeReq,
|
||||
GetMemPoolSizeReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
|
||||
PrefillAdder,
|
||||
SchedulePolicy,
|
||||
)
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
@@ -106,6 +110,9 @@ class Scheduler:
|
||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
|
||||
# Session info
|
||||
self.sessions = {}
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
|
||||
@@ -509,6 +516,11 @@ class Scheduler:
|
||||
self.start_profile()
|
||||
else:
|
||||
self.stop_profile()
|
||||
elif isinstance(recv_req, OpenSessionReqInput):
|
||||
session_id = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
||||
elif isinstance(recv_req, CloseSessionReqInput):
|
||||
self.close_session(recv_req)
|
||||
elif isinstance(recv_req, GetMemPoolSizeReq):
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
||||
@@ -520,14 +532,30 @@ class Scheduler:
|
||||
self,
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
lora_path=recv_req.lora_path,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
lora_path=recv_req.lora_path,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
if recv_req.session_id is not None:
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
f"Invalid request: session id {recv_req.session_id} does not exist"
|
||||
)
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
else:
|
||||
# Handle sessions
|
||||
session = self.sessions[recv_req.session_id]
|
||||
req, new_session_id = session.create_req(recv_req, self.tokenizer)
|
||||
del self.sessions[recv_req.session_id]
|
||||
self.sessions[new_session_id] = session
|
||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
# Image inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
@@ -1151,6 +1179,7 @@ class Scheduler:
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_no_stop_trim = []
|
||||
output_session_ids = []
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
|
||||
@@ -1178,6 +1207,7 @@ class Scheduler:
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
)
|
||||
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||
output_session_ids.append(req.session_id)
|
||||
|
||||
meta_info = {
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
@@ -1228,6 +1258,7 @@ class Scheduler:
|
||||
output_meta_info,
|
||||
output_finished_reason,
|
||||
output_no_stop_trim,
|
||||
output_session_ids,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
@@ -1330,6 +1361,25 @@ class Scheduler:
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> str:
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id in self.sessions:
|
||||
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||
else:
|
||||
self.sessions[session_id] = Session(
|
||||
recv_req.capacity_of_str_len, session_id
|
||||
)
|
||||
return session_id
|
||||
|
||||
def close_session(self, recv_req: CloseSessionReqInput):
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id not in self.sessions:
|
||||
logger.warning(f"session id {session_id} does not exist, cannot delete.")
|
||||
else:
|
||||
del self.sessions[session_id]
|
||||
|
||||
|
||||
def run_scheduler_process(
|
||||
server_args: ServerArgs,
|
||||
|
||||
62
python/sglang/srt/managers/session_controller.py
Normal file
62
python/sglang/srt/managers/session_controller.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, capacity_of_str_len: int, session_id: str = None):
|
||||
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
|
||||
self.capacity_of_str_len = capacity_of_str_len
|
||||
self.reqs: List[Req] = []
|
||||
|
||||
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
||||
# renew session id
|
||||
self.session_id = uuid.uuid4().hex
|
||||
if req.session_rid is not None:
|
||||
while len(self.reqs) > 0:
|
||||
if self.reqs[-1].rid == req.session_rid:
|
||||
break
|
||||
self.reqs = self.reqs[:-1]
|
||||
if len(self.reqs) > 0:
|
||||
input_ids = (
|
||||
self.reqs[-1].origin_input_ids
|
||||
+ self.reqs[-1].output_ids[
|
||||
: self.reqs[-1].sampling_params.max_new_tokens
|
||||
]
|
||||
+ req.input_ids
|
||||
)
|
||||
else:
|
||||
input_ids = req.input_ids
|
||||
new_req = Req(
|
||||
req.rid,
|
||||
None,
|
||||
input_ids,
|
||||
req.sampling_params,
|
||||
lora_path=req.lora_path,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
new_req.tokenizer = tokenizer
|
||||
if req.session_rid is not None and len(self.reqs) == 0:
|
||||
new_req.finished_reason = FINISH_ABORT(
|
||||
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
|
||||
)
|
||||
else:
|
||||
self.reqs.append(new_req)
|
||||
return new_req, self.session_id
|
||||
@@ -23,6 +23,7 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetMemPoolSizeReq,
|
||||
GetMemPoolSizeReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -146,6 +150,9 @@ class TokenizerManager:
|
||||
self.model_update_lock = asyncio.Lock()
|
||||
self.model_update_result = None
|
||||
|
||||
# For session info
|
||||
self.session_futures = {} # session_id -> asyncio event
|
||||
|
||||
# Others
|
||||
self.gracefully_exit = False
|
||||
|
||||
@@ -211,6 +218,8 @@ class TokenizerManager:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
session_id = obj.session_id
|
||||
session_rid = obj.session_rid
|
||||
|
||||
if len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
@@ -236,6 +245,8 @@ class TokenizerManager:
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
obj.lora_path,
|
||||
session_id=session_id,
|
||||
session_rid=session_rid,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
@@ -451,6 +462,26 @@ class TokenizerManager:
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
async def open_session(
|
||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
|
||||
session_id = uuid.uuid4().hex
|
||||
obj.session_id = session_id
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.session_futures[session_id] = asyncio.Future()
|
||||
session_id = await self.session_futures[session_id]
|
||||
del self.session_futures[session_id]
|
||||
return session_id
|
||||
|
||||
async def close_session(
|
||||
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
assert not self.to_create_loop, "close session should not be the first request"
|
||||
await self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
def create_abort_task(self, obj: GenerateReqInput):
|
||||
# Abort the request if the client is disconnected.
|
||||
async def abort_request():
|
||||
@@ -521,6 +552,11 @@ class TokenizerManager:
|
||||
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
||||
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
)
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||
@@ -536,11 +572,13 @@ class TokenizerManager:
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
"session_id": recv_obj.session_ids[i],
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
"session_id": recv_obj.session_ids[i],
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
|
||||
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
|
||||
)
|
||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import (
|
||||
CloseSessionReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
"""Open a session, and return its unique session id."""
|
||||
try:
|
||||
session_id = await tokenizer_manager.open_session(obj, request)
|
||||
return session_id
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||
async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
"""Close the session"""
|
||||
try:
|
||||
await tokenizer_manager.close_session(obj, request)
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@time_func_latency
|
||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
"""Handle a generate request."""
|
||||
|
||||
133
test/srt/test_session_id.py
Normal file
133
test/srt/test_session_id.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# FIXME: Make it a CI test
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
url = "http://localhost:30000"
|
||||
|
||||
# Open a session
|
||||
response = requests.post(
|
||||
url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
)
|
||||
session_id = response.json()
|
||||
print("session_id", session_id, "\n")
|
||||
|
||||
# Prefill only
|
||||
prompt = "chunk 1"
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"session_id": session_id,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 0,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json(), "\n")
|
||||
session_id = response.json()["session_id"]
|
||||
|
||||
# Generate
|
||||
prompt = "Chunk 2"
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"session_id": session_id,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json(), "\n")
|
||||
session_id = response.json()["session_id"]
|
||||
rid = response.json()["meta_info"]["id"]
|
||||
|
||||
# Generate
|
||||
prompt = "Chunk 3"
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"session_id": session_id,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 2,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json(), "\n")
|
||||
session_id = response.json()["session_id"]
|
||||
rid_to_del = response.json()["meta_info"]["id"]
|
||||
|
||||
# Interrupt and re-generate
|
||||
prompt = "Chunk 4"
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"session_id": session_id,
|
||||
"session_rid": rid,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json(), "\n")
|
||||
session_id = response.json()["session_id"]
|
||||
|
||||
# Query a session based on a deleted request, should see finish reason abort
|
||||
prompt = "Chunk 4"
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"session_id": session_id,
|
||||
"session_rid": rid_to_del,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json(), "\n")
|
||||
|
||||
# Close session
|
||||
ret = requests.post(
|
||||
url + "/close_session",
|
||||
json={"session_id": session_id},
|
||||
)
|
||||
print(ret, "\n")
|
||||
|
||||
# Query a deleted session, should see finish reason abort
|
||||
prompt = "chunk 1"
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"session_id": session_id,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 0,
|
||||
},
|
||||
},
|
||||
)
|
||||
print(response.json(), "\n")
|
||||
Reference in New Issue
Block a user