Add a new api configure_logging to allow dumping the requests (#2875)

Lianmin Zheng
2025-01-13 14:24:00 -08:00
committed by GitHub
parent 923f518337
commit 46d4431889
13 changed files with 164 additions and 71 deletions


@@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 import torch
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
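
(ConfigureLoggingReq itself is defined in sglang/srt/managers/io_struct.py, which this view does not include. Going only by the commit title, it presumably carries toggles for logging and dumping requests; a hypothetical sketch, with every field name assumed rather than taken from the diff:)

# Hypothetical shape of the new request type; the real definition lives in
# io_struct.py and is NOT shown in this diff.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None            # assumed: turn request logging on/off
    dump_requests_folder: Optional[str] = None     # assumed: directory for request dumps
    dump_requests_threshold: Optional[int] = None  # assumed: dump after this many requests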
@@ -161,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }


+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -178,8 +235,7 @@ async def flush_cache():
     )


-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -189,8 +245,7 @@ async def start_profile_async():
     )


-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
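
(Both profiling routes are now registered once via app.api_route with both verbs, instead of stacking @app.get and @app.post decorators; either method reaches the same handler. A quick check, with the server address assumed:)

# Either verb works for the profiling routes after this change; assumes a
# running server on the default port.
import requests

assert requests.get("http://localhost:30000/start_profile").ok
assert requests.post("http://localhost:30000/stop_profile").ok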
@@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)


-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session"""
tokenizer_manager.configure_logging(obj)
return Response(status_code=200)
##### OpenAI-compatible API endpoints #####
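
(A usage sketch for the new endpoint, which accepts both GET and POST. The JSON field names are assumptions, since the ConfigureLoggingReq definition is not part of this diff:)

# Turn on request dumping via the new /configure_logging endpoint; field
# names and server address are assumed for illustration.
import requests

resp = requests.post(
    "http://localhost:30000/configure_logging",
    json={"log_requests": True, "dump_requests_folder": "/tmp/sglang_request_dumps"},
)
assert resp.status_code == 200  # the handler returns a bare 200 on success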