From 46d44318894a13dc6d018892b32dd4a7e09f20f7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 13 Jan 2025 14:24:00 -0800 Subject: [PATCH] Add a new api configure_logging to allow dumping the requests (#2875) --- 3rdparty/amd/profiling/PROFILING.md | 2 +- 3rdparty/amd/profiling/server.sh | 2 +- 3rdparty/amd/tuning/TUNING.md | 2 +- benchmark/blog_v0_2/405b_sglang.sh | 2 +- .../sglang/srt/managers/configure_logging.py | 43 ++++++ python/sglang/srt/managers/io_struct.py | 7 + python/sglang/srt/managers/scheduler.py | 2 +- .../sglang/srt/managers/tokenizer_manager.py | 41 +++++- python/sglang/srt/mem_cache/memory_pool.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/server.py | 126 +++++++++--------- python/sglang/srt/server_args.py | 4 +- .../{ => srt}/torch_memory_saver_adapter.py | 0 13 files changed, 164 insertions(+), 71 deletions(-) create mode 100644 python/sglang/srt/managers/configure_logging.py rename python/sglang/{ => srt}/torch_memory_saver_adapter.py (100%) diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md index 79bc75b50..7e15ec844 100644 --- a/3rdparty/amd/profiling/PROFILING.md +++ b/3rdparty/amd/profiling/PROFILING.md @@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \ --model-path /sgl-workspace/sglang/dummy_grok1 \ --tokenizer-path Xenova/grok-1-tokenizer \ --load-format dummy \ - --quant fp8 \ + --quantization fp8 \ --tp 8 \ --port 30000 \ --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/3rdparty/amd/profiling/server.sh b/3rdparty/amd/profiling/server.sh index aa574f64c..f877e6c7a 100755 --- a/3rdparty/amd/profiling/server.sh +++ b/3rdparty/amd/profiling/server.sh @@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \ --model-path /sgl-workspace/sglang/dummy_grok1 \ --tokenizer-path Xenova/grok-1-tokenizer \ --load-format dummy \ - --quant fp8 \ + --quantization fp8 \ --tp 8 \ --port 30000 \ --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/3rdparty/amd/tuning/TUNING.md b/3rdparty/amd/tuning/TUNING.md index a38a16d4f..0638041c9 100644 --- a/3rdparty/amd/tuning/TUNING.md +++ b/3rdparty/amd/tuning/TUNING.md @@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes ```bash #Tuning -#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run). +#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run). 
#so we can tune decode moe use below command python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32" # and use this command to tune prefill moe diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh index 4e3372ae8..491853782 100644 --- a/benchmark/blog_v0_2/405b_sglang.sh +++ b/benchmark/blog_v0_2/405b_sglang.sh @@ -6,7 +6,7 @@ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # Launch sglang -# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 +# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87 # offline python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py new file mode 100644 index 000000000..3351cdc40 --- /dev/null +++ b/python/sglang/srt/managers/configure_logging.py @@ -0,0 +1,43 @@ +""" +Copyright 2023-2025 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Configure the logging settings of a server. 
+ +Usage: +python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="http://localhost:30000") + parser.add_argument( + "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump" + ) + parser.add_argument("--dump-requests-threshold", type=int, default=1000) + args = parser.parse_args() + + response = requests.post( + args.url + "/configure_logging", + json={ + "dump_requests_folder": args.dump_requests_folder, + "dump_requests_threshold": args.dump_requests_threshold, + }, + ) + assert response.status_code == 200 diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index ec45696bf..075693c7b 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -488,6 +488,13 @@ class ProfileReq(Enum): STOP_PROFILE = 2 +@dataclass +class ConfigureLoggingReq: + log_requests: Optional[bool] = None + dump_requests_folder: Optional[str] = None + dump_requests_threshold: Optional[int] = None + + @dataclass class OpenSessionReqInput: capacity_of_str_len: int diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b9e74aa9d..187216353 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -82,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( broadcast_pyobj, configure_logger, @@ -92,7 +93,6 @@ from sglang.srt.utils import ( set_random_seed, suppress_other_loggers, ) -from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 33968e34f..acd3b674a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -18,10 +18,12 @@ import copy import dataclasses import logging import os +import pickle import signal import sys import time import uuid +from datetime import datetime from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union import fastapi @@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import ( BatchStrOut, BatchTokenIDOut, CloseSessionReqInput, + ConfigureLoggingReq, EmbeddingReqInput, FlushCacheReq, GenerateReqInput, @@ -109,6 +112,7 @@ class TokenizerManager: # Parse args self.server_args = server_args self.enable_metrics = server_args.enable_metrics + self.log_requests = server_args.log_requests # Init inter-process communication context = zmq.asyncio.Context(2) @@ -167,6 +171,9 @@ class TokenizerManager: # Store states self.to_create_loop = True self.rid_to_state: Dict[str, ReqState] = {} + self.dump_requests_folder = "" # By default do not dump + self.dump_requests_threshold = 1000 + self.dump_request_list: List[Tuple] = [] # The event to notify the weight sync is finished. 
self.model_update_lock = RWLock() @@ -225,7 +232,7 @@ class TokenizerManager: obj.normalize_batch_and_arguments() - if self.server_args.log_requests: + if self.log_requests: logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}") async with self.model_update_lock.reader_lock: @@ -346,7 +353,7 @@ class TokenizerManager: state.out_list = [] if state.finished: - if self.server_args.log_requests: + if self.log_requests: msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}" logger.info(msg) del self.rid_to_state[obj.rid] @@ -597,6 +604,15 @@ class TokenizerManager: assert not self.to_create_loop, "close session should not be the first request" await self.send_to_scheduler.send_pyobj(obj) + def configure_logging(self, obj: ConfigureLoggingReq): + if obj.log_requests is not None: + self.log_requests = obj.log_requests + if obj.dump_requests_folder is not None: + self.dump_requests_folder = obj.dump_requests_folder + if obj.dump_requests_threshold is not None: + self.dump_requests_threshold = obj.dump_requests_threshold + logging.info(f"Config logging: {obj=}") + def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. async def abort_request(): @@ -708,6 +724,8 @@ class TokenizerManager: if self.enable_metrics: self.collect_metrics(state, recv_obj, i) + if self.dump_requests_folder and state.finished: + self.dump_requests(state, out_dict) elif isinstance(recv_obj, OpenSessionReqOutput): self.session_futures[recv_obj.session_id].set_result( recv_obj.session_id if recv_obj.success else None @@ -850,6 +868,25 @@ class TokenizerManager: (time.time() - state.created_time) / completion_tokens ) + def dump_requests(self, state: ReqState, out_dict: dict): + self.dump_request_list.append( + (state.obj, out_dict, state.created_time, time.time()) + ) + + if len(self.dump_request_list) >= self.dump_requests_threshold: + to_dump = self.dump_request_list + self.dump_request_list = [] + + def background_task(): + os.makedirs(self.dump_requests_folder, exist_ok=True) + current_time = datetime.now() + filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl" + with open(os.path.join(self.dump_requests_folder, filename), "wb") as f: + pickle.dump(to_dump, f) + + # Schedule the task to run in the background without awaiting it + asyncio.create_task(asyncio.to_thread(background_task)) + class SignalHandler: def __init__(self, tokenizer_manager): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 0761169e4..ab27e81b7 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter """ Memory pool. 
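Editor's note: the dump_requests method added to TokenizerManager above buffers finished requests as (request_obj, output_dict, created_time, finish_time) tuples and, once dump_requests_threshold is reached, pickles the batch to a timestamped file under dump_requests_folder. As a minimal sketch (not part of this patch; the folder path and helper are illustrative), such dump files could be read back like this:

```python
# Hypothetical inspection helper, not included in the patch.
# Each .pkl file written by TokenizerManager.dump_requests holds a list of
# (request_obj, output_dict, created_time, finish_time) tuples.
import glob
import os
import pickle

dump_folder = "/tmp/sglang_request_dump"  # assumed; must match dump_requests_folder

for path in sorted(glob.glob(os.path.join(dump_folder, "*.pkl"))):
    with open(path, "rb") as f:
        records = pickle.load(f)
    print(f"{os.path.basename(path)}: {len(records)} requests")
    for obj, out, created_time, finish_time in records[:3]:
        # Both timestamps come from time.time(), so the difference is in seconds.
        print(f"  latency = {finish_time - created_time:.3f} s")
```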
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 190427649..238f8603a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -50,6 +50,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, @@ -60,7 +61,6 @@ from sglang.srt.utils import ( monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) -from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 4e837e538..93fe1304c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -31,7 +31,7 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union import torch -from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -54,6 +54,7 @@ from sglang.srt.managers.data_parallel_controller import ( from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( CloseSessionReqInput, + ConfigureLoggingReq, EmbeddingReqInput, GenerateReqInput, GetWeightsByNameReqInput, @@ -161,12 +162,68 @@ async def get_model_info(): @app.get("/get_server_info") async def get_server_info(): return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args + **dataclasses.asdict(tokenizer_manager.server_args), **scheduler_info, "version": __version__, } +# fastapi implicitly converts json in the request to obj (dataclass) +@app.api_route("/generate", methods=["POST", "PUT"]) +@time_func_latency +async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" + if obj.stream: + + async def stream_results() -> AsyncIterator[bytes]: + try: + async for out in tokenizer_manager.generate_request(obj, request): + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + except ValueError as e: + out = {"error": {"message": str(e)}} + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + yield b"data: [DONE]\n\n" + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(obj), + ) + else: + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + logger.error(f"Error: {e}") + return _create_error_response(e) + + +@app.api_route("/encode", methods=["POST", "PUT"]) +@time_func_latency +async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.api_route("/classify", methods=["POST", "PUT"]) +@time_func_latency +async def classify_request(obj: EmbeddingReqInput, request: Request): + """Handle a reward model request. 
Now the arguments and return values are the same as embedding models.""" + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + @app.post("/flush_cache") async def flush_cache(): """Flush the radix cache.""" @@ -178,8 +235,7 @@ async def flush_cache(): ) -@app.get("/start_profile") -@app.post("/start_profile") +@app.api_route("/start_profile", methods=["GET", "POST"]) async def start_profile_async(): """Start profiling.""" tokenizer_manager.start_profile() @@ -189,8 +245,7 @@ async def start_profile_async(): ) -@app.get("/stop_profile") -@app.post("/stop_profile") +@app.api_route("/stop_profile", methods=["GET", "POST"]) async def stop_profile_async(): """Stop profiling.""" tokenizer_manager.stop_profile() @@ -305,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request): return _create_error_response(e) -# fastapi implicitly converts json in the request to obj (dataclass) -@app.api_route("/generate", methods=["POST", "PUT"]) -@time_func_latency -async def generate_request(obj: GenerateReqInput, request: Request): - """Handle a generate request.""" - if obj.stream: - - async def stream_results() -> AsyncIterator[bytes]: - try: - async for out in tokenizer_manager.generate_request(obj, request): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - except ValueError as e: - out = {"error": {"message": str(e)}} - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - yield b"data: [DONE]\n\n" - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), - ) - else: - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - logger.error(f"Error: {e}") - return _create_error_response(e) - - -@app.api_route("/encode", methods=["POST", "PUT"]) -@time_func_latency -async def encode_request(obj: EmbeddingReqInput, request: Request): - """Handle an embedding request.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -@app.api_route("/classify", methods=["POST", "PUT"]) -@time_func_latency -async def classify_request(obj: EmbeddingReqInput, request: Request): - """Handle a reward model request. 
Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Configure the request logging options."""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)


 ##### OpenAI-compatible API endpoints #####
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4f44d5c87..57a82c18a 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -91,7 +91,7 @@ class ServerArgs:

     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
     enable_cache_report: bool = False

     # Data parallelism
@@ -554,7 +554,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
+            help="The log interval of decode batch.",
         )

         # API related
diff --git a/python/sglang/torch_memory_saver_adapter.py b/python/sglang/srt/torch_memory_saver_adapter.py
similarity index 100%
rename from python/sglang/torch_memory_saver_adapter.py
rename to python/sglang/srt/torch_memory_saver_adapter.py
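Editor's note: taken together, the patch adds a ConfigureLoggingReq dataclass, routes it from the new /configure_logging endpoint in server.py to TokenizerManager.configure_logging, and ships a small CLI (sglang.srt.managers.configure_logging) that posts to that endpoint. A minimal client-side sketch, assuming a server is already running on localhost:30000 (the URL, folder, and threshold below are examples; the JSON fields mirror the ConfigureLoggingReq dataclass):

```python
# Assumed client call; mirrors python/sglang/srt/managers/configure_logging.py,
# but also toggles log_requests, which controls the Receive/Finish log lines.
import requests

response = requests.post(
    "http://localhost:30000/configure_logging",  # assumed server address
    json={
        "log_requests": True,
        "dump_requests_folder": "/tmp/sglang_request_dump",
        "dump_requests_threshold": 1000,
    },
)
assert response.status_code == 200
```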