Add a new API configure_logging to allow dumping requests (#2875)
@@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union

 import torch

-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
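The hunk above imports ConfigureLoggingReq from sglang.srt.managers.io_struct, but the dataclass itself is defined outside this diff. A minimal sketch of what such a request object could look like, assuming optional fields for toggling request logging and dumping; every field name here is an illustrative assumption, not taken from this commit:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigureLoggingReq:
    # Turn request logging on or off; None leaves the current setting unchanged.
    log_requests: Optional[bool] = None
    # Folder to dump request payloads into (assumed field name).
    dump_requests_folder: Optional[str] = None
    # Number of buffered requests before a dump file is written (assumed field name).
    dump_requests_threshold: Optional[int] = None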
@@ -161,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }


+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
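Since the /generate handler above streams server-sent events terminated by a data: [DONE] sentinel, a minimal client sketch for consuming that stream looks like the following; the server address and the request body fields are assumptions for illustration, not part of this diff:

import json

import requests

# Issue a streaming generate request (assumed server address and fields).
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines():
    if not line:
        continue  # skip the blank separators between events
    text = line.decode("utf-8")
    if not text.startswith("data: "):
        continue
    payload = text[len("data: "):]
    if payload == "[DONE]":
        break  # the server signals end-of-stream with this sentinel
    print(json.loads(payload))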
@@ -178,8 +235,7 @@ async def flush_cache():
     )


-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -189,8 +245,7 @@ async def start_profile_async():
     )


-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
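With the two hunks above, /start_profile and /stop_profile each accept both GET and POST through a single api_route decorator instead of two stacked ones. A typical round trip, assuming a server at http://localhost:30000:

import requests

base = "http://localhost:30000"  # assumed server address

requests.post(f"{base}/start_profile")  # begin collecting profiler data
# ... send some /generate traffic to capture ...
requests.post(f"{base}/stop_profile")   # stop profiling and flush the results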
@@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)


-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Configure the request logging options."""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)


 ##### OpenAI-compatible API endpoints #####
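To exercise the new endpoint, a client POSTs a JSON body that fastapi deserializes into ConfigureLoggingReq. A sketch, assuming the server address and reusing the assumed field names from the dataclass sketch earlier:

import requests

# Enable request logging and dumping; the field names are assumptions, since
# ConfigureLoggingReq's definition is not part of this diff.
resp = requests.post(
    "http://localhost:30000/configure_logging",
    json={"log_requests": True, "dump_requests_folder": "/tmp/sglang_requests"},
)
assert resp.status_code == 200  # the handler returns an empty 200 response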