[feat] Enable Ascend profiling on SGLang (#8610)
Co-authored-by: liyou_b <2953090824@qq.com>
@@ -8,6 +8,18 @@ import torch
 
 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)
 
 logger = logging.getLogger(__name__)
 
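Note on the hunk above: torch_npu._apply_patches rebinds the listed torch.profiler attributes to their torch_npu counterparts, which is what lets the CUDA-oriented profiler code later in this file run unchanged on Ascend. A minimal standalone sketch of the same idea, assuming a working torch_npu install (the patch list mirrors the hunk; the profiling snippet is an illustration of the assumed post-patch behavior, not part of the commit):

import torch
import torch_npu

# Same redirection the hunk performs at sglang import time:
torch_npu._apply_patches(
    [
        ["profiler.profile", torch_npu.profiler.profile],
        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
    ]
)

# Code written against torch.profiler now drives the Ascend profiler:
# after the patch, ProfilerActivity.CUDA resolves to ProfilerActivity.NPU.
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,  # remapped to NPU by the patch
    ]
) as prof:
    x = torch.randn(64, 64).npu()  # assumes an Ascend device is available
    y = x @ x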
@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
                 activities=torchprof_activities,
                 with_stack=with_stack if with_stack is not None else True,
                 record_shapes=record_shapes if record_shapes is not None else False,
+                on_trace_ready=(
+                    None
+                    if not _is_npu
+                    else torch_npu.profiler.tensorboard_trace_handler(
+                        self.torch_profiler_output_dir
+                    )
+                ),
             )
             self.torch_profiler.start()
         self.profile_in_progress = True
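Why on_trace_ready differs by backend: on CUDA it stays None and SGLang exports a Chrome trace itself when profiling stops (see the next hunk); on NPU, torch_npu.profiler.tensorboard_trace_handler persists the trace as the profiler stops. A hedged sketch of that NPU flow in isolation, using only the calls this hunk introduces (output directory is illustrative):

import torch_npu

prof = torch_npu.profiler.profile(
    activities=[torch_npu.profiler.ProfilerActivity.NPU],
    on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("/tmp/npu_traces"),
)
prof.start()
# ... run the NPU workload to be profiled ...
prof.stop()  # the handler flushes the trace into /tmp/npu_traces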
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
-                )
-            )
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
+                )
             torch.distributed.barrier(self.tp_cpu_group)
 
         if self.rpd_profiler is not None:
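The new `if not _is_npu:` guard skips SGLang's own Chrome-trace export on Ascend, since the tensorboard_trace_handler installed at start time already wrote the trace; the barrier still runs on both backends. For reference, the CUDA-path filename assembled above resolves like this (all values illustrative):

import os

output_dir = "/tmp/profile_out"  # self.torch_profiler_output_dir, illustrative
profile_id, tp_rank, stage_suffix = "1700000000", 0, "-decode"  # illustrative
path = os.path.join(
    output_dir, profile_id + f"-TP-{tp_rank}" + stage_suffix + ".trace.json.gz"
)
# -> /tmp/profile_out/1700000000-TP-0-decode.trace.json.gz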