Add the ability to enable and disable the Profiler via HTTP API. (#1626)

This commit is contained in:
科英
2024-10-11 17:34:25 +08:00
committed by GitHub
parent b503881bd2
commit bbd72bfc86
4 changed files with 73 additions and 0 deletions

View File

@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -343,3 +344,8 @@ class UpdateWeightReqOutput:
class AbortReq: class AbortReq:
# The request id # The request id
rid: str rid: str
class ProfileReq(Enum):
START_PROFILE = 1
STOP_PROFILE = 2

View File

@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut, BatchEmbeddingOut,
BatchTokenIDOut, BatchTokenIDOut,
FlushCacheReq, FlushCacheReq,
ProfileReq,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
TokenizedRewardReqInput, TokenizedRewardReqInput,
@@ -229,6 +230,22 @@ class Scheduler:
self.new_token_ratio_decay = global_config.new_token_ratio_decay self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.batch_is_full = False self.batch_is_full = False
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
else:
self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
logger.info(
"Profiling enabled. Traces will be saved to: %s",
self.torch_profiler_trace_dir,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
)
@torch.inference_mode() @torch.inference_mode()
def event_loop(self): def event_loop(self):
while True: while True:
@@ -271,6 +288,11 @@ class Scheduler:
elif isinstance(recv_req, UpdateWeightReqInput): elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req) success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message)) self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
else: else:
raise ValueError(f"Invalid request: {recv_req}") raise ValueError(f"Invalid request: {recv_req}")
@@ -1000,6 +1022,20 @@ class Scheduler:
logger.error(message) logger.error(message)
return success, message return success, message
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
self.profiler.export_chrome_trace(
self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
)
logger.info("Profiler is done")
def run_scheduler_process( def run_scheduler_process(
server_args: ServerArgs, server_args: ServerArgs,

View File

@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
ProfileReq,
RewardReqInput, RewardReqInput,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
@@ -512,6 +513,14 @@ class TokenizerManager:
req = AbortReq(rid) req = AbortReq(rid)
self.send_to_scheduler.send_pyobj(req) self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
req = ProfileReq.START_PROFILE
self.send_to_scheduler.send_pyobj(req)
def stop_profile(self):
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def update_weights( async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
): ):

View File

@@ -145,6 +145,28 @@ async def flush_cache():
) )
@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
"""Start profiling."""
tokenizer_manager.start_profile()
return Response(
content="Start profiling.\n",
status_code=200,
)
@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
"""Stop profiling."""
tokenizer_manager.stop_profile()
return Response(
content="Stop profiling. This will take some time.\n",
status_code=200,
)
@app.post("/update_weights") @app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request): async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server.""" """Update the weights inplace without re-launching the server."""