[feat] Add session control (#2073)
This commit is contained in:
@@ -23,6 +23,7 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetMemPoolSizeReq,
|
||||
GetMemPoolSizeReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -146,6 +150,9 @@ class TokenizerManager:
|
||||
self.model_update_lock = asyncio.Lock()
|
||||
self.model_update_result = None
|
||||
|
||||
# For session info
|
||||
self.session_futures = {} # session_id -> asyncio.Future
|
||||
|
||||
# Others
|
||||
self.gracefully_exit = False
|
||||
|
||||
@@ -211,6 +218,8 @@ class TokenizerManager:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
session_id = obj.session_id
|
||||
session_rid = obj.session_rid
|
||||
|
||||
if len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
@@ -236,6 +245,8 @@ class TokenizerManager:
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
obj.lora_path,
|
||||
session_id=session_id,
|
||||
session_rid=session_rid,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
@@ -451,6 +462,26 @@ class TokenizerManager:
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Open a new session on the scheduler and return its id.

    A fresh hex UUID is generated, written to ``obj.session_id``, and the
    request is sent to the scheduler. The coroutine then waits on a future
    stored in ``self.session_futures`` under that id; the receive loop
    resolves it when the matching ``OpenSessionReqOutput`` arrives.

    Args:
        obj: The open-session request. Its ``session_id`` is overwritten
            with the newly generated id.
        request: Not used by this method.

    Returns:
        The value the future was resolved with (the session id).
    """
    if self.to_create_loop:
        self.create_handle_loop()

    session_id = uuid.uuid4().hex
    obj.session_id = session_id

    # Register the future before sending so the entry already exists by the
    # time a reply can be processed.
    session_future = asyncio.Future()
    self.session_futures[session_id] = session_future
    try:
        self.send_to_scheduler.send_pyobj(obj)
        return await session_future
    finally:
        # Always drop the bookkeeping entry, including when the wait is
        # cancelled or the send fails, so abandoned futures do not pile up.
        self.session_futures.pop(session_id, None)
|
||||
|
||||
async def close_session(
    self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Forward a close-session request to the scheduler.

    Unlike ``open_session`` this never creates the handle loop; it asserts
    that the loop has already been created by an earlier request.

    Args:
        obj: The close-session request to send.
        request: Not used by this method.
    """
    assert not self.to_create_loop, "close session should not be the first request"
    scheduler_socket = self.send_to_scheduler
    await scheduler_socket.send_pyobj(obj)
|
||||
|
||||
def create_abort_task(self, obj: GenerateReqInput):
|
||||
# Abort the request if the client is disconnected.
|
||||
async def abort_request():
|
||||
@@ -521,6 +552,11 @@ class TokenizerManager:
|
||||
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
||||
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
)
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||
@@ -536,11 +572,13 @@ class TokenizerManager:
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
"session_id": recv_obj.session_ids[i],
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
"session_id": recv_obj.session_ids[i],
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
|
||||
Reference in New Issue
Block a user