Add the ability to enable and disable the Profiler via HTTP API. (#1626)
This commit is contained in:
@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
FlushCacheReq,
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
TokenizedRewardReqInput,
|
||||
@@ -229,6 +230,22 @@ class Scheduler:
|
||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||
self.batch_is_full = False
|
||||
|
||||
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
||||
self.profiler = None
|
||||
else:
|
||||
self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s",
|
||||
self.torch_profiler_trace_dir,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
while True:
|
||||
@@ -271,6 +288,11 @@ class Scheduler:
|
||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||
success, message = self.update_weights(recv_req)
|
||||
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
else:
|
||||
self.stop_profile()
|
||||
else:
|
||||
raise ValueError(f"Invalid request: {recv_req}")
|
||||
|
||||
@@ -1000,6 +1022,20 @@ class Scheduler:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.start()
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
self.profiler.export_chrome_trace(
|
||||
self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
|
||||
def run_scheduler_process(
|
||||
server_args: ServerArgs,
|
||||
|
||||
Reference in New Issue
Block a user