[feat] Add session control (#2073)

This commit is contained in:
Ying Sheng
2024-11-20 00:36:53 -08:00
committed by GitHub
parent 63a395b985
commit 5942dfc00a
8 changed files with 348 additions and 8 deletions

View File

@@ -23,6 +23,7 @@ import os
import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
CloseSessionReqInput,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -146,6 +150,9 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
# For session info
self.session_futures = {} # session_id -> asyncio.Future, resolved with the session id when an OpenSessionReqOutput for it is received
# Others
self.gracefully_exit = False
@@ -211,6 +218,8 @@ class TokenizerManager:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_id = obj.session_id
session_rid = obj.session_rid
if len(input_ids) >= self.context_len:
raise ValueError(
@@ -236,6 +245,8 @@ class TokenizerManager:
top_logprobs_num,
obj.stream,
obj.lora_path,
session_id=session_id,
session_rid=session_rid,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +462,26 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Open a new session and return its id once it has been acknowledged.

    A fresh hex UUID is generated and written to ``obj.session_id``, a pending
    future is stored in ``self.session_futures`` under that id, and ``obj`` is
    sent through ``self.send_to_scheduler``. The coroutine then waits until the
    future is resolved (the receive loop calls ``set_result`` on it when an
    ``OpenSessionReqOutput`` with the same id arrives).

    Args:
        obj: The open-session request; its ``session_id`` is overwritten.
        request: Not used by this method.

    Returns:
        The value the pending future is resolved with, i.e. the acknowledged
        session id.
    """
    if self.to_create_loop:
        self.create_handle_loop()

    session_id = uuid.uuid4().hex
    obj.session_id = session_id

    pending = asyncio.get_running_loop().create_future()
    # Drop the bookkeeping entry as soon as the future completes, regardless
    # of whether this coroutine is still around to do it. Done-callbacks run
    # after set_result() returns, so the receive loop's dict lookup has
    # already happened by then.
    pending.add_done_callback(
        lambda _done: self.session_futures.pop(session_id, None)
    )
    # Register before sending so a reply can never arrive for an id that is
    # not yet present in the table.
    self.session_futures[session_id] = pending

    # Not awaited, as in the original code path.
    self.send_to_scheduler.send_pyobj(obj)

    # shield(): if the caller is cancelled while waiting, the stored future
    # must stay pending (not cancelled), otherwise the receive loop's later
    # set_result() on it would raise InvalidStateError.
    return await asyncio.shield(pending)
async def close_session(
    self,
    obj: CloseSessionReqInput,
    request: Optional[fastapi.Request] = None,
):
    """Forward a close-session request to the scheduler.

    Args:
        obj: The close-session request, sent as-is via
            ``self.send_to_scheduler``.
        request: Not used by this method.

    Raises:
        AssertionError: If ``self.to_create_loop`` is still set, i.e. this is
            called before any request that creates the handle loop.
    """
    handle_loop_exists = not self.to_create_loop
    assert handle_loop_exists, "close session should not be the first request"

    # The send result is awaited here, so the coroutine finishes only after
    # the awaitable returned by send_pyobj() has completed.
    send_result = self.send_to_scheduler.send_pyobj(obj)
    await send_result
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
@@ -521,6 +552,11 @@ class TokenizerManager:
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue
assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -536,11 +572,13 @@ class TokenizerManager:
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)