Add a new api configure_logging to allow dumping the requests (#2875)
3rdparty/amd/profiling/PROFILING.md (vendored, 2 changes)

@@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \
     --model-path /sgl-workspace/sglang/dummy_grok1 \
     --tokenizer-path Xenova/grok-1-tokenizer \
     --load-format dummy \
-    --quant fp8 \
+    --quantization fp8 \
     --tp 8 \
     --port 30000 \
     --disable-radix-cache 2>&1 | tee "$LOGFILE"
3rdparty/amd/profiling/server.sh (vendored, 2 changes)

@@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \
     --model-path /sgl-workspace/sglang/dummy_grok1 \
     --tokenizer-path Xenova/grok-1-tokenizer \
     --load-format dummy \
-    --quant fp8 \
+    --quantization fp8 \
     --tp 8 \
     --port 30000 \
     --disable-radix-cache 2>&1 | tee "$LOGFILE"
3rdparty/amd/tuning/TUNING.md (vendored, 2 changes)

@@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the best
 
 ```bash
 #Tuning
-#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
+#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
 #so we can tune decode moe use below command
 python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
 # and use this command to tune prefill moe
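The comment above packs the batch arithmetic into one sentence; a plain-Python sketch of the same computation, with the numbers taken from the example command:

```python
# Prefill runs the moe kernel over all input tokens at once;
# decode generates one token per sequence per step.
batch_size, input_len = 32, 1024

prefill_batch = batch_size * input_len  # 32 * 1024 = 32768
decode_batch = batch_size * 1           # 32

print(prefill_batch, decode_batch)  # 32768 32
```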
@@ -6,7 +6,7 @@
 # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
 
 # Launch sglang
-# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
+# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
 
 # offline
 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
python/sglang/srt/managers/configure_logging.py (new file, 43 lines)

@@ -0,0 +1,43 @@
"""
|
||||
Copyright 2023-2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Configure the logging settings of a server.
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
||||
parser.add_argument(
|
||||
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
||||
)
|
||||
parser.add_argument("--dump-requests-threshold", type=int, default=1000)
|
||||
args = parser.parse_args()
|
||||
|
||||
response = requests.post(
|
||||
args.url + "/configure_logging",
|
||||
json={
|
||||
"dump_requests_folder": args.dump_requests_folder,
|
||||
"dump_requests_threshold": args.dump_requests_threshold,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
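The script only sets the dump options. The endpoint also accepts the `log_requests` flag from `ConfigureLoggingReq` (see the dataclass below), so per-request logging can be toggled with a direct call; a minimal sketch, assuming a server on the default port:

```python
import requests

# Fields omitted from the JSON body arrive as None and are left unchanged,
# so this toggles request logging without touching the dump settings.
response = requests.post(
    "http://localhost:30000/configure_logging",
    json={"log_requests": True},
)
assert response.status_code == 200
```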
@@ -488,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
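Every field defaults to `None`, which the tokenizer manager reads as "leave this setting unchanged", so callers can adjust one option at a time. A small sketch of that semantics, assuming the sglang package is importable:

```python
from sglang.srt.managers.io_struct import ConfigureLoggingReq

# A request body of {"dump_requests_threshold": 100} deserializes to:
req = ConfigureLoggingReq(dump_requests_threshold=100)
assert req.log_requests is None         # untouched settings stay None
assert req.dump_requests_folder is None
```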
@@ -82,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -92,7 +93,6 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -18,10 +18,12 @@ import copy
 import dataclasses
 import logging
 import os
+import pickle
 import signal
 import sys
 import time
 import uuid
+from datetime import datetime
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
@@ -109,6 +112,7 @@ class TokenizerManager:
         # Parse args
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
+        self.log_requests = server_args.log_requests
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -167,6 +171,9 @@ class TokenizerManager:
         # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.dump_requests_folder = ""  # By default do not dump
+        self.dump_requests_threshold = 1000
+        self.dump_request_list: List[Tuple] = []
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -225,7 +232,7 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
-        if self.server_args.log_requests:
+        if self.log_requests:
             logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
 
         async with self.model_update_lock.reader_lock:
@@ -346,7 +353,7 @@ class TokenizerManager:
 
             state.out_list = []
             if state.finished:
-                if self.server_args.log_requests:
+                if self.log_requests:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
@@ -597,6 +604,15 @@ class TokenizerManager:
         assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)
 
+    def configure_logging(self, obj: ConfigureLoggingReq):
+        if obj.log_requests is not None:
+            self.log_requests = obj.log_requests
+        if obj.dump_requests_folder is not None:
+            self.dump_requests_folder = obj.dump_requests_folder
+        if obj.dump_requests_threshold is not None:
+            self.dump_requests_threshold = obj.dump_requests_threshold
+        logging.info(f"Config logging: {obj=}")
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -708,6 +724,8 @@ class TokenizerManager:
 
                 if self.enable_metrics:
                     self.collect_metrics(state, recv_obj, i)
+                if self.dump_requests_folder and state.finished:
+                    self.dump_requests(state, out_dict)
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id if recv_obj.success else None
@@ -850,6 +868,25 @@ class TokenizerManager:
                     (time.time() - state.created_time) / completion_tokens
                 )
 
+    def dump_requests(self, state: ReqState, out_dict: dict):
+        self.dump_request_list.append(
+            (state.obj, out_dict, state.created_time, time.time())
+        )
+
+        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            to_dump = self.dump_request_list
+            self.dump_request_list = []
+
+            def background_task():
+                os.makedirs(self.dump_requests_folder, exist_ok=True)
+                current_time = datetime.now()
+                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
+                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                    pickle.dump(to_dump, f)
+
+            # Schedule the task to run in the background without awaiting it
+            asyncio.create_task(asyncio.to_thread(background_task))
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
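Each dump file therefore holds a list of `(obj, out_dict, created_time, finished_time)` tuples. A minimal sketch for inspecting a dump offline, assuming the default folder from the CLI script above:

```python
import os
import pickle

folder = "/tmp/sglang_request_dump"  # matches --dump-requests-folder default
for name in sorted(os.listdir(folder)):
    with open(os.path.join(folder, name), "rb") as f:
        records = pickle.load(f)
    for obj, out, created_time, finished_time in records:
        # obj is the original request object; out is the decoded output dict
        print(name, f"latency={finished_time - created_time:.3f}s")
```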
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
 Memory pool.
@@ -50,6 +50,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -60,7 +61,6 @@ from sglang.srt.utils import (
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 logger = logging.getLogger(__name__)
@@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 
 import torch
 
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -161,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -178,8 +235,7 @@ async def flush_cache():
     )
 
 
-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -189,8 +245,7 @@ async def start_profile_async():
     )
 
 
-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
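After this consolidation either verb reaches the same handler; a quick sketch, assuming the default port:

```python
import requests

# GET and POST now both route to the profiling handlers.
requests.post("http://localhost:30000/start_profile")
requests.get("http://localhost:30000/stop_profile")
```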
@@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Configure the request logging options."""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -91,7 +91,7 @@ class ServerArgs:
 
     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
     enable_cache_report: bool = False
 
     # Data parallelism
@@ -554,7 +554,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
+            help="The log interval of decode batch.",
         )
 
         # API related