Updated Instructions on Profiling SGLang Infer System with AMD GPUs (#1966)
Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
126
3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch
vendored
Normal file
126
3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..2edb427 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -245,6 +247,7 @@ class Scheduler:
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
+ self.rpd = rpdTracerControl()
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
@@ -1027,15 +1030,26 @@ class Scheduler:
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.start()
|
||||
+ #self.profiler.start()
|
||||
+ logger.info("torch profiler is disabled")
|
||||
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
||||
+ self.rpd.setPythonTrace(True)
|
||||
+ self.rpd.start()
|
||||
+ self.rpd.rangePush("", "scheduler", "")
|
||||
+ logger.info("rpd is enabled inside scheduler profiling")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ #self.profiler.stop()
|
||||
+ #self.profiler.export_chrome_trace(
|
||||
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
+ #)
|
||||
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
||||
+ self.rpd.rangePop()
|
||||
+ self.rpd.stop()
|
||||
+ self.rpd.flush()
|
||||
+ logger.info("rpd is done inside scheduler")
|
||||
logger.info("Profiler is done")
|
||||
|
||||
|
||||
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
index 2621ccd..181df85 100644
|
||||
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
||||
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
||||
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
+
|
||||
+
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -514,10 +518,20 @@ class TokenizerManager:
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def start_profile(self):
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.setPythonTrace(True)
|
||||
+ rpd.start()
|
||||
+ rpd.rangePush("", "tokenizer_manager", "")
|
||||
+ logger.info("tokenizer_manager rpd profiling started!")
|
||||
req = ProfileReq.START_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def stop_profile(self):
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.rangePop()
|
||||
+ rpd.stop()
|
||||
+ rpd.flush()
|
||||
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
||||
req = ProfileReq.STOP_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
||||
index 7111c93..2bd722c 100644
|
||||
--- a/python/sglang/srt/server.py
|
||||
+++ b/python/sglang/srt/server.py
|
||||
@@ -30,6 +30,8 @@ import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -152,6 +154,11 @@ async def flush_cache():
|
||||
@app.post("/start_profile")
|
||||
async def start_profile():
|
||||
"""Start profiling."""
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.setPythonTrace(True)
|
||||
+ rpd.start()
|
||||
+ rpd.rangePush("", "server rpd profile range", "")
|
||||
+ logger.info("rpd profiling started in server.py!")
|
||||
tokenizer_manager.start_profile()
|
||||
return Response(
|
||||
content="Start profiling.\n",
|
||||
@@ -164,6 +171,11 @@ async def start_profile():
|
||||
async def stop_profile():
|
||||
"""Stop profiling."""
|
||||
tokenizer_manager.stop_profile()
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.rangePop()
|
||||
+ rpd.stop()
|
||||
+ rpd.flush()
|
||||
+ logger.info("rpd profiling is done in server.py!")
|
||||
return Response(
|
||||
content="Stop profiling. This will take some time.\n",
|
||||
status_code=200,
|
||||
Reference in New Issue
Block a user