From bbd72bfc8609d1c5d8bc9ebb29b9c3b9e218bb90 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=A7=91=E8=8B=B1?=
Date: Fri, 11 Oct 2024 17:34:25 +0800
Subject: [PATCH] Add the ability to enable and disable the Profiler via HTTP
 API. (#1626)

---
 python/sglang/srt/managers/io_struct.py       |  6 ++++
 python/sglang/srt/managers/scheduler.py       | 36 +++++++++++++++++++
 .../sglang/srt/managers/tokenizer_manager.py  |  9 +++++
 python/sglang/srt/server.py                   | 22 ++++++++++++
 4 files changed, 73 insertions(+)

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 00e1c5b7c..c9ee00e9d 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 
 import uuid
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -343,3 +344,8 @@ class UpdateWeightReqOutput:
 class AbortReq:
     # The request id
     rid: str
+
+
+class ProfileReq(Enum):
+    START_PROFILE = 1
+    STOP_PROFILE = 2
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 34bf39289..10411cd3e 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     TokenizedRewardReqInput,
@@ -229,6 +230,22 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
+            self.profiler = None
+        else:
+            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
+            logger.info(
+                "Profiling enabled. Traces will be saved to: %s",
+                self.torch_profiler_trace_dir,
+            )
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+            )
+
     @torch.inference_mode()
     def event_loop(self):
         while True:
@@ -271,6 +288,11 @@ class Scheduler:
         elif isinstance(recv_req, UpdateWeightReqInput):
             success, message = self.update_weights(recv_req)
             self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+        elif isinstance(recv_req, ProfileReq):
+            if recv_req == ProfileReq.START_PROFILE:
+                self.start_profile()
+            else:
+                self.stop_profile()
         else:
             raise ValueError(f"Invalid request: {recv_req}")
 
@@ -1000,6 +1022,20 @@ class Scheduler:
             logger.error(message)
         return success, message
 
+    def start_profile(self) -> None:
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self) -> None:
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+        self.profiler.export_chrome_trace(
+            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        )
+        logger.info("Profiler is done")
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b25290c0a..2621ccd4f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    ProfileReq,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -512,6 +513,14 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)
 
+    def start_profile(self):
+        req = ProfileReq.START_PROFILE
+        self.send_to_scheduler.send_pyobj(req)
+
+    def stop_profile(self):
+        req = ProfileReq.STOP_PROFILE
+        self.send_to_scheduler.send_pyobj(req)
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index c6b2a345b..ff3364012 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -145,6 +145,28 @@ async def flush_cache():
     )
 
 
+@app.get("/start_profile")
+@app.post("/start_profile")
+async def start_profile():
+    """Start profiling."""
+    tokenizer_manager.start_profile()
+    return Response(
+        content="Start profiling.\n",
+        status_code=200,
+    )
+
+
+@app.get("/stop_profile")
+@app.post("/stop_profile")
+async def stop_profile():
+    """Stop profiling."""
+    tokenizer_manager.stop_profile()
+    return Response(
+        content="Stop profiling. This will take some time.\n",
+        status_code=200,
+    )
+
+
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""