Add the ability to enable and disable the Profiler via HTTP API. (#1626)
This commit is contained in:
@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
@@ -343,3 +344,8 @@ class UpdateWeightReqOutput:
|
|||||||
class AbortReq:
|
class AbortReq:
|
||||||
# The request id
|
# The request id
|
||||||
rid: str
|
rid: str
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileReq(Enum):
    """Control message selecting which profiler action is requested."""

    # Request that profiling begin.
    START_PROFILE = 1
    # Request that profiling end.
    STOP_PROFILE = 2
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
|
ProfileReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
TokenizedRewardReqInput,
|
TokenizedRewardReqInput,
|
||||||
@@ -229,6 +230,22 @@ class Scheduler:
|
|||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
|
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
||||||
|
self.profiler = None
|
||||||
|
else:
|
||||||
|
self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
|
||||||
|
logger.info(
|
||||||
|
"Profiling enabled. Traces will be saved to: %s",
|
||||||
|
self.torch_profiler_trace_dir,
|
||||||
|
)
|
||||||
|
self.profiler = torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
with_stack=True,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def event_loop(self):
|
def event_loop(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -271,6 +288,11 @@ class Scheduler:
|
|||||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||||
success, message = self.update_weights(recv_req)
|
success, message = self.update_weights(recv_req)
|
||||||
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
|
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
|
||||||
|
elif isinstance(recv_req, ProfileReq):
|
||||||
|
if recv_req == ProfileReq.START_PROFILE:
|
||||||
|
self.start_profile()
|
||||||
|
else:
|
||||||
|
self.stop_profile()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
raise ValueError(f"Invalid request: {recv_req}")
|
||||||
|
|
||||||
@@ -1000,6 +1022,20 @@ class Scheduler:
|
|||||||
logger.error(message)
|
logger.error(message)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def start_profile(self) -> None:
    """Start recording with ``self.profiler``.

    Raises:
        RuntimeError: If ``self.profiler`` is ``None``, i.e. no profiler
            was configured for this object.
    """
    profiler = self.profiler
    if profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    profiler.start()
|
||||||
|
|
||||||
|
def stop_profile(self) -> None:
    """Stop ``self.profiler`` and export the collected trace.

    The trace is written through ``export_chrome_trace`` to a file named
    ``<time.time()>.trace.json.gz`` inside ``self.torch_profiler_trace_dir``.

    Raises:
        RuntimeError: If ``self.profiler`` is ``None``, i.e. no profiler
            was configured for this object.
    """
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    self.profiler.stop()
    # Create the output directory on demand: if it is missing, the export
    # would fail only after the profile has already been collected.
    os.makedirs(self.torch_profiler_trace_dir, exist_ok=True)
    trace_path = os.path.join(
        self.torch_profiler_trace_dir, str(time.time()) + ".trace.json.gz"
    )
    self.profiler.export_chrome_trace(trace_path)
    logger.info("Profiler is done")
|
||||||
|
|
||||||
|
|
||||||
def run_scheduler_process(
|
def run_scheduler_process(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
|
ProfileReq,
|
||||||
RewardReqInput,
|
RewardReqInput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -512,6 +513,14 @@ class TokenizerManager:
|
|||||||
req = AbortReq(rid)
|
req = AbortReq(rid)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
|
def start_profile(self):
    """Send a ``ProfileReq.START_PROFILE`` message to the scheduler.

    The message goes out through ``self.send_to_scheduler``; no reply is
    read here and nothing is returned.
    """
    self.send_to_scheduler.send_pyobj(ProfileReq.START_PROFILE)
|
||||||
|
|
||||||
|
def stop_profile(self):
    """Send a ``ProfileReq.STOP_PROFILE`` message to the scheduler.

    The message goes out through ``self.send_to_scheduler``; no reply is
    read here and nothing is returned.
    """
    self.send_to_scheduler.send_pyobj(ProfileReq.STOP_PROFILE)
|
||||||
|
|
||||||
async def update_weights(
|
async def update_weights(
|
||||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -145,6 +145,28 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    # Registered under both GET and POST, so either verb triggers it.
    # The result of the call below is not inspected: the 200 response only
    # reports that the start request was issued.
    tokenizer_manager.start_profile()
    ack = Response(content="Start profiling.\n", status_code=200)
    return ack
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    # Registered under both GET and POST, so either verb triggers it.
    # The result of the call below is not inspected: the 200 response only
    # reports that the stop request was issued.
    tokenizer_manager.stop_profile()
    ack = Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )
    return ack
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights")
|
@app.post("/update_weights")
|
||||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||||
"""Update the weights inplace without re-launching the server."""
|
"""Update the weights inplace without re-launching the server."""
|
||||||
|
|||||||
Reference in New Issue
Block a user