diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py
index 3d061a8fe..afbab8205 100644
--- a/python/sglang/srt/managers/scheduler_profiler_mixin.py
+++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py
@@ -8,6 +8,18 @@ import torch
 
 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
             activities=torchprof_activities,
             with_stack=with_stack if with_stack is not None else True,
             record_shapes=record_shapes if record_shapes is not None else False,
+            on_trace_ready=(
+                None
+                if not _is_npu
+                else torch_npu.profiler.tensorboard_trace_handler(
+                    self.torch_profiler_output_dir
+                )
+            ),
         )
         self.torch_profiler.start()
         self.profile_in_progress = True
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
                 )
-            )
         torch.distributed.barrier(self.tp_cpu_group)
 
         if self.rpd_profiler is not None:
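
Note (not part of the diff): a minimal sketch of the NPU code path introduced above, assuming a working torch_npu install; `output_dir` and the profiled workload are placeholders. On NPU, the tensorboard_trace_handler passed as on_trace_ready writes the trace out when the profiler stops, which is why the diff skips the CUDA-style export_chrome_trace() call in that branch.

    import torch_npu

    output_dir = "/tmp/npu_traces"  # placeholder output directory

    # Build the profiler directly against torch_npu; in the patched code
    # path above, torch.profiler.profile resolves to torch_npu.profiler.profile.
    prof = torch_npu.profiler.profile(
        activities=[
            torch_npu.profiler.ProfilerActivity.CPU,
            torch_npu.profiler.ProfilerActivity.NPU,
        ],
        # The handler dumps the trace into output_dir on stop, so no
        # explicit export_chrome_trace() call is needed on NPU.
        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(output_dir),
    )
    prof.start()
    # ... run the forward passes to be profiled ...
    prof.stop()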