Add a new api configure_logging to allow dumping the requests (#2875)
This commit is contained in:
2
3rdparty/amd/profiling/PROFILING.md
vendored
2
3rdparty/amd/profiling/PROFILING.md
vendored
@@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \
|
|||||||
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
||||||
--tokenizer-path Xenova/grok-1-tokenizer \
|
--tokenizer-path Xenova/grok-1-tokenizer \
|
||||||
--load-format dummy \
|
--load-format dummy \
|
||||||
--quant fp8 \
|
--quantization fp8 \
|
||||||
--tp 8 \
|
--tp 8 \
|
||||||
--port 30000 \
|
--port 30000 \
|
||||||
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
||||||
|
|||||||
2
3rdparty/amd/profiling/server.sh
vendored
2
3rdparty/amd/profiling/server.sh
vendored
@@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \
|
|||||||
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
||||||
--tokenizer-path Xenova/grok-1-tokenizer \
|
--tokenizer-path Xenova/grok-1-tokenizer \
|
||||||
--load-format dummy \
|
--load-format dummy \
|
||||||
--quant fp8 \
|
--quantization fp8 \
|
||||||
--tp 8 \
|
--tp 8 \
|
||||||
--port 30000 \
|
--port 30000 \
|
||||||
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
||||||
|
|||||||
2
3rdparty/amd/tuning/TUNING.md
vendored
2
3rdparty/amd/tuning/TUNING.md
vendored
@@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
#Tuning
|
#Tuning
|
||||||
#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
|
#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
|
||||||
#so we can tune decode moe use below command
|
#so we can tune decode moe use below command
|
||||||
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
|
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
|
||||||
# and use this command to tune prefill moe
|
# and use this command to tune prefill moe
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
||||||
|
|
||||||
# Launch sglang
|
# Launch sglang
|
||||||
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
|
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
|
||||||
|
|
||||||
# offline
|
# offline
|
||||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
||||||
|
|||||||
43
python/sglang/srt/managers/configure_logging.py
Normal file
43
python/sglang/srt/managers/configure_logging.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2025 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Configure the logging settings of a server.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
||||||
|
)
|
||||||
|
parser.add_argument("--dump-requests-threshold", type=int, default=1000)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
args.url + "/configure_logging",
|
||||||
|
json={
|
||||||
|
"dump_requests_folder": args.dump_requests_folder,
|
||||||
|
"dump_requests_threshold": args.dump_requests_threshold,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
@@ -488,6 +488,13 @@ class ProfileReq(Enum):
|
|||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConfigureLoggingReq:
|
||||||
|
log_requests: Optional[bool] = None
|
||||||
|
dump_requests_folder: Optional[str] = None
|
||||||
|
dump_requests_threshold: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenSessionReqInput:
|
class OpenSessionReqInput:
|
||||||
capacity_of_str_len: int
|
capacity_of_str_len: int
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
@@ -92,7 +93,6 @@ from sglang.srt.utils import (
|
|||||||
set_random_seed,
|
set_random_seed,
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
)
|
)
|
||||||
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -18,10 +18,12 @@ import copy
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
@@ -109,6 +112,7 @@ class TokenizerManager:
|
|||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
|
self.log_requests = server_args.log_requests
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
@@ -167,6 +171,9 @@ class TokenizerManager:
|
|||||||
# Store states
|
# Store states
|
||||||
self.to_create_loop = True
|
self.to_create_loop = True
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
|
self.dump_requests_threshold = 1000
|
||||||
|
self.dump_request_list: List[Tuple] = []
|
||||||
|
|
||||||
# The event to notify the weight sync is finished.
|
# The event to notify the weight sync is finished.
|
||||||
self.model_update_lock = RWLock()
|
self.model_update_lock = RWLock()
|
||||||
@@ -225,7 +232,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
obj.normalize_batch_and_arguments()
|
obj.normalize_batch_and_arguments()
|
||||||
|
|
||||||
if self.server_args.log_requests:
|
if self.log_requests:
|
||||||
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
|
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
|
||||||
|
|
||||||
async with self.model_update_lock.reader_lock:
|
async with self.model_update_lock.reader_lock:
|
||||||
@@ -346,7 +353,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
state.out_list = []
|
state.out_list = []
|
||||||
if state.finished:
|
if state.finished:
|
||||||
if self.server_args.log_requests:
|
if self.log_requests:
|
||||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
del self.rid_to_state[obj.rid]
|
del self.rid_to_state[obj.rid]
|
||||||
@@ -597,6 +604,15 @@ class TokenizerManager:
|
|||||||
assert not self.to_create_loop, "close session should not be the first request"
|
assert not self.to_create_loop, "close session should not be the first request"
|
||||||
await self.send_to_scheduler.send_pyobj(obj)
|
await self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
|
||||||
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
||||||
|
if obj.log_requests is not None:
|
||||||
|
self.log_requests = obj.log_requests
|
||||||
|
if obj.dump_requests_folder is not None:
|
||||||
|
self.dump_requests_folder = obj.dump_requests_folder
|
||||||
|
if obj.dump_requests_threshold is not None:
|
||||||
|
self.dump_requests_threshold = obj.dump_requests_threshold
|
||||||
|
logging.info(f"Config logging: {obj=}")
|
||||||
|
|
||||||
def create_abort_task(self, obj: GenerateReqInput):
|
def create_abort_task(self, obj: GenerateReqInput):
|
||||||
# Abort the request if the client is disconnected.
|
# Abort the request if the client is disconnected.
|
||||||
async def abort_request():
|
async def abort_request():
|
||||||
@@ -708,6 +724,8 @@ class TokenizerManager:
|
|||||||
|
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
self.collect_metrics(state, recv_obj, i)
|
self.collect_metrics(state, recv_obj, i)
|
||||||
|
if self.dump_requests_folder and state.finished:
|
||||||
|
self.dump_requests(state, out_dict)
|
||||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
recv_obj.session_id if recv_obj.success else None
|
recv_obj.session_id if recv_obj.success else None
|
||||||
@@ -850,6 +868,25 @@ class TokenizerManager:
|
|||||||
(time.time() - state.created_time) / completion_tokens
|
(time.time() - state.created_time) / completion_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def dump_requests(self, state: ReqState, out_dict: dict):
|
||||||
|
self.dump_request_list.append(
|
||||||
|
(state.obj, out_dict, state.created_time, time.time())
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self.dump_request_list) >= self.dump_requests_threshold:
|
||||||
|
to_dump = self.dump_request_list
|
||||||
|
self.dump_request_list = []
|
||||||
|
|
||||||
|
def background_task():
|
||||||
|
os.makedirs(self.dump_requests_folder, exist_ok=True)
|
||||||
|
current_time = datetime.now()
|
||||||
|
filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
|
||||||
|
with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
|
||||||
|
pickle.dump(to_dump, f)
|
||||||
|
|
||||||
|
# Schedule the task to run in the background without awaiting it
|
||||||
|
asyncio.create_task(asyncio.to_thread(background_task))
|
||||||
|
|
||||||
|
|
||||||
class SignalHandler:
|
class SignalHandler:
|
||||||
def __init__(self, tokenizer_manager):
|
def __init__(self, tokenizer_manager):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Memory pool.
|
Memory pool.
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
@@ -60,7 +61,6 @@ from sglang.srt.utils import (
|
|||||||
monkey_patch_vllm_p2p_access_check,
|
monkey_patch_vllm_p2p_access_check,
|
||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
)
|
)
|
||||||
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
@@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import (
|
|||||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
@@ -161,12 +162,68 @@ async def get_model_info():
|
|||||||
@app.get("/get_server_info")
|
@app.get("/get_server_info")
|
||||||
async def get_server_info():
|
async def get_server_info():
|
||||||
return {
|
return {
|
||||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
**dataclasses.asdict(tokenizer_manager.server_args),
|
||||||
**scheduler_info,
|
**scheduler_info,
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||||
|
@app.api_route("/generate", methods=["POST", "PUT"])
|
||||||
|
@time_func_latency
|
||||||
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
|
"""Handle a generate request."""
|
||||||
|
if obj.stream:
|
||||||
|
|
||||||
|
async def stream_results() -> AsyncIterator[bytes]:
|
||||||
|
try:
|
||||||
|
async for out in tokenizer_manager.generate_request(obj, request):
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
out, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
except ValueError as e:
|
||||||
|
out = {"error": {"message": str(e)}}
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
out, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_results(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(obj),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
|
@time_func_latency
|
||||||
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
|
"""Handle an embedding request."""
|
||||||
|
try:
|
||||||
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/classify", methods=["POST", "PUT"])
|
||||||
|
@time_func_latency
|
||||||
|
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||||
|
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||||
|
try:
|
||||||
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/flush_cache")
|
@app.post("/flush_cache")
|
||||||
async def flush_cache():
|
async def flush_cache():
|
||||||
"""Flush the radix cache."""
|
"""Flush the radix cache."""
|
||||||
@@ -178,8 +235,7 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/start_profile")
|
@app.api_route("/start_profile", methods=["GET", "POST"])
|
||||||
@app.post("/start_profile")
|
|
||||||
async def start_profile_async():
|
async def start_profile_async():
|
||||||
"""Start profiling."""
|
"""Start profiling."""
|
||||||
tokenizer_manager.start_profile()
|
tokenizer_manager.start_profile()
|
||||||
@@ -189,8 +245,7 @@ async def start_profile_async():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stop_profile")
|
@app.api_route("/stop_profile", methods=["GET", "POST"])
|
||||||
@app.post("/stop_profile")
|
|
||||||
async def stop_profile_async():
|
async def stop_profile_async():
|
||||||
"""Stop profiling."""
|
"""Stop profiling."""
|
||||||
tokenizer_manager.stop_profile()
|
tokenizer_manager.stop_profile()
|
||||||
@@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
|||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||||
@app.api_route("/generate", methods=["POST", "PUT"])
|
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||||
@time_func_latency
|
"""Close the session"""
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
tokenizer_manager.configure_logging(obj)
|
||||||
"""Handle a generate request."""
|
return Response(status_code=200)
|
||||||
if obj.stream:
|
|
||||||
|
|
||||||
async def stream_results() -> AsyncIterator[bytes]:
|
|
||||||
try:
|
|
||||||
async for out in tokenizer_manager.generate_request(obj, request):
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
out, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
except ValueError as e:
|
|
||||||
out = {"error": {"message": str(e)}}
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
out, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
yield b"data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
stream_results(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
background=tokenizer_manager.create_abort_task(obj),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Error: {e}")
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
|
||||||
@time_func_latency
|
|
||||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
|
||||||
"""Handle an embedding request."""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/classify", methods=["POST", "PUT"])
|
|
||||||
@time_func_latency
|
|
||||||
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
|
||||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
##### OpenAI-compatible API endpoints #####
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
# API related
|
# API related
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
file_storage_pth: str = "SGLang_storage"
|
file_storage_pth: str = "sglang_storage"
|
||||||
enable_cache_report: bool = False
|
enable_cache_report: bool = False
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
@@ -554,7 +554,7 @@ class ServerArgs:
|
|||||||
"--decode-log-interval",
|
"--decode-log-interval",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.decode_log_interval,
|
default=ServerArgs.decode_log_interval,
|
||||||
help="The log interval of decode batch",
|
help="The log interval of decode batch.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# API related
|
# API related
|
||||||
|
|||||||
Reference in New Issue
Block a user