From f6528b74be50e1b82bd06e895e9a382669438661 Mon Sep 17 00:00:00 2001
From: lizhigong <306128847@qq.com>
Date: Tue, 28 Oct 2025 16:25:06 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0hipprof=E6=94=AF=E6=8C=81?=
 =?UTF-8?q?=E3=80=81=E4=BF=AE=E5=A4=8D=E5=BC=82=E6=AD=A5=E8=B0=83=E5=BA=A6?=
 =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=90=8C=E6=AD=A5=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (this section is ignored by git am):
 - schedule_batch.py: seq_lens_sum now holds the result of .sum() rather than
   a Python int from .item(). TODO confirm every consumer accepts that; if an
   int is required, summing self.seq_lens_cpu may avoid the device sync while
   keeping the type (assumes seq_lens_cpu mirrors seq_lens - verify).
 - prof.py: ProfRangePush issued roctxRangePushA twice per call while
   recording a depth of one, roc_tracer_flag/lib were not initialised in
   __init__, and methods used the global `profile` instead of `self`.
 - The blob id on the prof.py index line predates these review edits.

 python/sglang/srt/managers/schedule_batch.py |  2 +-
 python/sglang/srt/profile/prof.py            | 85 +++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+), 1 deletion(-)
 create mode 100644 python/sglang/srt/profile/prof.py

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index be2de0cc7..c07bb4360 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
-        self.seq_lens_sum = self.seq_lens.sum().item()
+        self.seq_lens_sum = self.seq_lens.sum()
         self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
diff --git a/python/sglang/srt/profile/prof.py b/python/sglang/srt/profile/prof.py
new file mode 100644
index 000000000..d82cdeb27
--- /dev/null
+++ b/python/sglang/srt/profile/prof.py
@@ -0,0 +1,85 @@
+"""Optional roctx range markers and roctracer control for HIP profiling.
+
+Profiling is active only when the ``SGLANG_HIP_PROF`` environment variable is
+set; otherwise every method on ``Prof`` is a no-op.
+"""
+
+import os
+import threading
+import time
+from ctypes import c_char_p, c_int, cdll
+
+_ROCTRACER_LIB = "libroctracer64.so"
+
+
+class Prof:
+    """ctypes wrapper around roctracer start/stop and roctx range push/pop."""
+
+    def __init__(self):
+        self.use_roctx = os.getenv("SGLANG_HIP_PROF") is not None
+        # Defined unconditionally so the range helpers are safe to call
+        # before StartTracer() and when profiling is disabled.
+        self.lib = None
+        self.roc_tracer_flag = False
+        self.tm = time.perf_counter()
+        # Count of currently open ranges, keyed by thread ident.
+        self.push_depth = {}
+        if self.use_roctx:
+            self._load_lib()
+
+    def _load_lib(self):
+        """Load the tracer library once and declare the roctx signatures."""
+        if self.lib is None:
+            lib = cdll.LoadLibrary(_ROCTRACER_LIB)
+            lib.roctxRangePushA.argtypes = [c_char_p]
+            lib.roctxRangePushA.restype = c_int
+            lib.roctxRangePop.restype = c_int
+            self.lib = lib
+        return self.lib
+
+    def StartTracer(self):
+        """Start roctracer collection and enable range markers."""
+        if self.use_roctx:
+            self._load_lib().roctracer_start()
+            self.roc_tracer_flag = True
+
+    def StopTracer(self):
+        """Stop roctracer collection and disable range markers."""
+        if self.use_roctx:
+            self._load_lib().roctracer_stop()
+            self.roc_tracer_flag = False
+
+    def thread_depth_add(self, num):
+        """Adjust the calling thread's open-range count by ``num``.
+
+        Returns False, leaving the count unchanged, when ``num`` is negative
+        and the thread has no open range; returns True otherwise.
+        """
+        thread_id = threading.get_ident()
+        depth = self.push_depth.setdefault(thread_id, 0)
+        if num < 0 and depth == 0:
+            return False
+        self.push_depth[thread_id] = depth + num
+        return True
+
+    def ProfRangePush(self, message):
+        """Open a range named ``message`` on the calling thread while tracing."""
+        if self.use_roctx and self.roc_tracer_flag:
+            # Exactly one push per recorded depth level so pops stay balanced.
+            self.lib.roctxRangePushA(message.encode("utf-8"))
+            self.thread_depth_add(1)
+
+    def ProfRangePop(self):
+        """Close the calling thread's innermost open range, if there is one."""
+        if self.use_roctx and self.roc_tracer_flag:
+            if not self.thread_depth_add(-1):
+                return
+            self.lib.roctxRangePop()
+
+    def ProfRangeAutoPush(self, message):
+        """Close the current range, if any, then open one named ``message``."""
+        self.ProfRangePop()
+        self.ProfRangePush(message)
+
+
+profile = Prof()