[feat] Add session control (#2073)

This commit is contained in:
Ying Sheng
2024-11-20 00:36:53 -08:00
committed by GitHub
parent 63a395b985
commit 5942dfc00a
8 changed files with 348 additions and 8 deletions

View File

@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
EmbeddingReqInput,
GenerateReqInput,
OpenSessionReqInput,
UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
try:
session_id = await tokenizer_manager.open_session(obj, request)
return session_id
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
try:
await tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""