[feat] Add session control (#2073)
This commit is contained in:
@@ -175,6 +175,7 @@ class DetokenizerManager:
|
|||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
meta_info=recv_obj.meta_info,
|
meta_info=recv_obj.meta_info,
|
||||||
finished_reason=recv_obj.finished_reason,
|
finished_reason=recv_obj.finished_reason,
|
||||||
|
session_ids=recv_obj.session_ids,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ class GenerateReqInput:
|
|||||||
# LoRA related
|
# LoRA related
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
|
# Session id info for continual prompting
|
||||||
|
session_id: Optional[Union[List[str], str]] = None
|
||||||
|
session_rid: Optional[Union[List[str], str]] = None
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
if (self.text is None and self.input_ids is None) or (
|
||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
|
|||||||
# LoRA related
|
# LoRA related
|
||||||
lora_path: Optional[str] = None # None means just use the base model
|
lora_path: Optional[str] = None # None means just use the base model
|
||||||
|
|
||||||
|
# Session id info for continual prompting
|
||||||
|
session_id: Optional[int] = None
|
||||||
|
session_rid: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
|
|||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
no_stop_trim: List[bool]
|
no_stop_trim: List[bool]
|
||||||
|
# The updated session unique id
|
||||||
|
session_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -305,6 +315,8 @@ class BatchStrOut:
|
|||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
|
# The update session unique id
|
||||||
|
session_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class GetMemPoolSizeReqOutput:
|
class GetMemPoolSizeReqOutput:
|
||||||
size: int
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OpenSessionReqInput:
|
||||||
|
capacity_of_str_len: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CloseSessionReqInput:
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OpenSessionReqOutput:
|
||||||
|
session_id: str
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class Req:
|
|||||||
origin_input_ids: Tuple[int],
|
origin_input_ids: Tuple[int],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
lora_path: Optional[str] = None,
|
lora_path: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
@@ -188,6 +189,8 @@ class Req:
|
|||||||
self.origin_input_ids = origin_input_ids
|
self.origin_input_ids = origin_input_ids
|
||||||
self.output_ids = [] # Each decode stage's output ids
|
self.output_ids = [] # Each decode stage's output ids
|
||||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||||
|
self.session_id = session_id
|
||||||
|
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
self.lora_path = lora_path
|
self.lora_path = lora_path
|
||||||
|
|
||||||
|
|||||||
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
|
|||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
CloseSessionReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GetMemPoolSizeReq,
|
GetMemPoolSizeReq,
|
||||||
GetMemPoolSizeReqOutput,
|
GetMemPoolSizeReqOutput,
|
||||||
|
OpenSessionReqInput,
|
||||||
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
|
|||||||
PrefillAdder,
|
PrefillAdder,
|
||||||
SchedulePolicy,
|
SchedulePolicy,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.session_controller import Session
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
@@ -106,6 +110,9 @@ class Scheduler:
|
|||||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
|
|
||||||
|
# Session info
|
||||||
|
self.sessions = {}
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
|
|
||||||
@@ -509,6 +516,11 @@ class Scheduler:
|
|||||||
self.start_profile()
|
self.start_profile()
|
||||||
else:
|
else:
|
||||||
self.stop_profile()
|
self.stop_profile()
|
||||||
|
elif isinstance(recv_req, OpenSessionReqInput):
|
||||||
|
session_id = self.open_session(recv_req)
|
||||||
|
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
||||||
|
elif isinstance(recv_req, CloseSessionReqInput):
|
||||||
|
self.close_session(recv_req)
|
||||||
elif isinstance(recv_req, GetMemPoolSizeReq):
|
elif isinstance(recv_req, GetMemPoolSizeReq):
|
||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
||||||
@@ -520,14 +532,30 @@ class Scheduler:
|
|||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: TokenizedGenerateReqInput,
|
||||||
):
|
):
|
||||||
req = Req(
|
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||||
recv_req.rid,
|
req = Req(
|
||||||
recv_req.input_text,
|
recv_req.rid,
|
||||||
recv_req.input_ids,
|
recv_req.input_text,
|
||||||
recv_req.sampling_params,
|
recv_req.input_ids,
|
||||||
lora_path=recv_req.lora_path,
|
recv_req.sampling_params,
|
||||||
)
|
lora_path=recv_req.lora_path,
|
||||||
req.tokenizer = self.tokenizer
|
)
|
||||||
|
req.tokenizer = self.tokenizer
|
||||||
|
if recv_req.session_id is not None:
|
||||||
|
req.finished_reason = FINISH_ABORT(
|
||||||
|
f"Invalid request: session id {recv_req.session_id} does not exist"
|
||||||
|
)
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Handle sessions
|
||||||
|
session = self.sessions[recv_req.session_id]
|
||||||
|
req, new_session_id = session.create_req(recv_req, self.tokenizer)
|
||||||
|
del self.sessions[recv_req.session_id]
|
||||||
|
self.sessions[new_session_id] = session
|
||||||
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
return
|
||||||
|
|
||||||
# Image inputs
|
# Image inputs
|
||||||
if recv_req.image_inputs is not None:
|
if recv_req.image_inputs is not None:
|
||||||
@@ -1151,6 +1179,7 @@ class Scheduler:
|
|||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
output_spaces_between_special_tokens = []
|
output_spaces_between_special_tokens = []
|
||||||
output_no_stop_trim = []
|
output_no_stop_trim = []
|
||||||
|
output_session_ids = []
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
output_embeddings = []
|
output_embeddings = []
|
||||||
|
|
||||||
@@ -1178,6 +1207,7 @@ class Scheduler:
|
|||||||
req.sampling_params.spaces_between_special_tokens
|
req.sampling_params.spaces_between_special_tokens
|
||||||
)
|
)
|
||||||
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||||
|
output_session_ids.append(req.session_id)
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": len(req.origin_input_ids),
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
@@ -1228,6 +1258,7 @@ class Scheduler:
|
|||||||
output_meta_info,
|
output_meta_info,
|
||||||
output_finished_reason,
|
output_finished_reason,
|
||||||
output_no_stop_trim,
|
output_no_stop_trim,
|
||||||
|
output_session_ids,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
@@ -1330,6 +1361,25 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
logger.info("Profiler is done")
|
logger.info("Profiler is done")
|
||||||
|
|
||||||
|
def open_session(self, recv_req: OpenSessionReqInput) -> str:
|
||||||
|
# handle error
|
||||||
|
session_id = recv_req.session_id
|
||||||
|
if session_id in self.sessions:
|
||||||
|
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||||
|
else:
|
||||||
|
self.sessions[session_id] = Session(
|
||||||
|
recv_req.capacity_of_str_len, session_id
|
||||||
|
)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
def close_session(self, recv_req: CloseSessionReqInput):
|
||||||
|
# handle error
|
||||||
|
session_id = recv_req.session_id
|
||||||
|
if session_id not in self.sessions:
|
||||||
|
logger.warning(f"session id {session_id} does not exist, cannot delete.")
|
||||||
|
else:
|
||||||
|
del self.sessions[session_id]
|
||||||
|
|
||||||
|
|
||||||
def run_scheduler_process(
|
def run_scheduler_process(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
|||||||
62
python/sglang/srt/managers/session_controller.py
Normal file
62
python/sglang/srt/managers/session_controller.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||||
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
def __init__(self, capacity_of_str_len: int, session_id: str = None):
|
||||||
|
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
|
||||||
|
self.capacity_of_str_len = capacity_of_str_len
|
||||||
|
self.reqs: List[Req] = []
|
||||||
|
|
||||||
|
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
||||||
|
# renew session id
|
||||||
|
self.session_id = uuid.uuid4().hex
|
||||||
|
if req.session_rid is not None:
|
||||||
|
while len(self.reqs) > 0:
|
||||||
|
if self.reqs[-1].rid == req.session_rid:
|
||||||
|
break
|
||||||
|
self.reqs = self.reqs[:-1]
|
||||||
|
if len(self.reqs) > 0:
|
||||||
|
input_ids = (
|
||||||
|
self.reqs[-1].origin_input_ids
|
||||||
|
+ self.reqs[-1].output_ids[
|
||||||
|
: self.reqs[-1].sampling_params.max_new_tokens
|
||||||
|
]
|
||||||
|
+ req.input_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_ids = req.input_ids
|
||||||
|
new_req = Req(
|
||||||
|
req.rid,
|
||||||
|
None,
|
||||||
|
input_ids,
|
||||||
|
req.sampling_params,
|
||||||
|
lora_path=req.lora_path,
|
||||||
|
session_id=self.session_id,
|
||||||
|
)
|
||||||
|
new_req.tokenizer = tokenizer
|
||||||
|
if req.session_rid is not None and len(self.reqs) == 0:
|
||||||
|
new_req.finished_reason = FINISH_ABORT(
|
||||||
|
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reqs.append(new_req)
|
||||||
|
return new_req, self.session_id
|
||||||
@@ -23,6 +23,7 @@ import os
|
|||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
CloseSessionReqInput,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetMemPoolSizeReq,
|
GetMemPoolSizeReq,
|
||||||
GetMemPoolSizeReqOutput,
|
GetMemPoolSizeReqOutput,
|
||||||
|
OpenSessionReqInput,
|
||||||
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -146,6 +150,9 @@ class TokenizerManager:
|
|||||||
self.model_update_lock = asyncio.Lock()
|
self.model_update_lock = asyncio.Lock()
|
||||||
self.model_update_result = None
|
self.model_update_result = None
|
||||||
|
|
||||||
|
# For session info
|
||||||
|
self.session_futures = {} # session_id -> asyncio event
|
||||||
|
|
||||||
# Others
|
# Others
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
|
|
||||||
@@ -211,6 +218,8 @@ class TokenizerManager:
|
|||||||
return_logprob = obj.return_logprob
|
return_logprob = obj.return_logprob
|
||||||
logprob_start_len = obj.logprob_start_len
|
logprob_start_len = obj.logprob_start_len
|
||||||
top_logprobs_num = obj.top_logprobs_num
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
|
session_id = obj.session_id
|
||||||
|
session_rid = obj.session_rid
|
||||||
|
|
||||||
if len(input_ids) >= self.context_len:
|
if len(input_ids) >= self.context_len:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -236,6 +245,8 @@ class TokenizerManager:
|
|||||||
top_logprobs_num,
|
top_logprobs_num,
|
||||||
obj.stream,
|
obj.stream,
|
||||||
obj.lora_path,
|
obj.lora_path,
|
||||||
|
session_id=session_id,
|
||||||
|
session_rid=session_rid,
|
||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
@@ -451,6 +462,26 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
return False, "Another update is in progress. Please try again later."
|
return False, "Another update is in progress. Please try again later."
|
||||||
|
|
||||||
|
async def open_session(
|
||||||
|
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||||
|
):
|
||||||
|
if self.to_create_loop:
|
||||||
|
self.create_handle_loop()
|
||||||
|
|
||||||
|
session_id = uuid.uuid4().hex
|
||||||
|
obj.session_id = session_id
|
||||||
|
self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
self.session_futures[session_id] = asyncio.Future()
|
||||||
|
session_id = await self.session_futures[session_id]
|
||||||
|
del self.session_futures[session_id]
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
async def close_session(
|
||||||
|
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
||||||
|
):
|
||||||
|
assert not self.to_create_loop, "close session should not be the first request"
|
||||||
|
await self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
|
||||||
def create_abort_task(self, obj: GenerateReqInput):
|
def create_abort_task(self, obj: GenerateReqInput):
|
||||||
# Abort the request if the client is disconnected.
|
# Abort the request if the client is disconnected.
|
||||||
async def abort_request():
|
async def abort_request():
|
||||||
@@ -521,6 +552,11 @@ class TokenizerManager:
|
|||||||
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
||||||
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||||
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
|
recv_obj.session_id
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||||
@@ -536,11 +572,13 @@ class TokenizerManager:
|
|||||||
out_dict = {
|
out_dict = {
|
||||||
"text": recv_obj.output_strs[i],
|
"text": recv_obj.output_strs[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
|
"session_id": recv_obj.session_ids[i],
|
||||||
}
|
}
|
||||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"token_ids": recv_obj.output_ids[i],
|
"token_ids": recv_obj.output_ids[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
|
"session_id": recv_obj.session_ids[i],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
|
|||||||
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
CloseSessionReqInput,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
|
OpenSessionReqInput,
|
||||||
UpdateWeightReqInput,
|
UpdateWeightReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||||
|
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||||
|
"""Open a session, and return its unique session id."""
|
||||||
|
try:
|
||||||
|
session_id = await tokenizer_manager.open_session(obj, request)
|
||||||
|
return session_id
|
||||||
|
except Exception as e:
|
||||||
|
return ORJSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||||
|
async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||||
|
"""Close the session"""
|
||||||
|
try:
|
||||||
|
await tokenizer_manager.close_session(obj, request)
|
||||||
|
return Response(status_code=200)
|
||||||
|
except Exception as e:
|
||||||
|
return ORJSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
"""Handle a generate request."""
|
"""Handle a generate request."""
|
||||||
|
|||||||
133
test/srt/test_session_id.py
Normal file
133
test/srt/test_session_id.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# FIXME: Make it a CI test
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
|
||||||
|
url = "http://localhost:30000"
|
||||||
|
|
||||||
|
# Open a session
|
||||||
|
response = requests.post(
|
||||||
|
url + "/open_session",
|
||||||
|
json={"capacity_of_str_len": 1000},
|
||||||
|
)
|
||||||
|
session_id = response.json()
|
||||||
|
print("session_id", session_id, "\n")
|
||||||
|
|
||||||
|
# Prefill only
|
||||||
|
prompt = "chunk 1"
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"session_id": session_id,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json(), "\n")
|
||||||
|
session_id = response.json()["session_id"]
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
prompt = "Chunk 2"
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"session_id": session_id,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json(), "\n")
|
||||||
|
session_id = response.json()["session_id"]
|
||||||
|
rid = response.json()["meta_info"]["id"]
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
prompt = "Chunk 3"
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"session_id": session_id,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json(), "\n")
|
||||||
|
session_id = response.json()["session_id"]
|
||||||
|
rid_to_del = response.json()["meta_info"]["id"]
|
||||||
|
|
||||||
|
# Interrupt and re-generate
|
||||||
|
prompt = "Chunk 4"
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"session_id": session_id,
|
||||||
|
"session_rid": rid,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json(), "\n")
|
||||||
|
session_id = response.json()["session_id"]
|
||||||
|
|
||||||
|
# Query a session based on a deleted request, should see finish reason abort
|
||||||
|
prompt = "Chunk 4"
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"session_id": session_id,
|
||||||
|
"session_rid": rid_to_del,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json(), "\n")
|
||||||
|
|
||||||
|
# Close session
|
||||||
|
ret = requests.post(
|
||||||
|
url + "/close_session",
|
||||||
|
json={"session_id": session_id},
|
||||||
|
)
|
||||||
|
print(ret, "\n")
|
||||||
|
|
||||||
|
# Query a deleted session, should see finish reason abort
|
||||||
|
prompt = "chunk 1"
|
||||||
|
response = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"session_id": session_id,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json(), "\n")
|
||||||
Reference in New Issue
Block a user