diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 8dd7c6bc..5cccdaf0 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -60,6 +60,65 @@ _IS_MOE_MODEL = None _ENABLE_SP = None _HAS_LAYER_IDX = None _ENABLE_NZ = None +_SUBSCRIBED_COMPUTE_STREAMS = set() +_GRAPH_PRINT_STREAM = None +_GRAPH_PRINT_STREAM_LOCK = Lock() + + +def _print_callback_on_stream(*args): + """Callback function to print arguments on the dedicated print stream.""" + global _GRAPH_PRINT_STREAM + with torch_npu.npu.stream(_GRAPH_PRINT_STREAM): + print(*args, flush=True) + + +def acl_graph_print(*args): + """ + Prints arguments from within an ACL graph. + + This function is provided for developers to print debug information when encountering + issues within an ACL graph, pretty handy for dumping input/output tensor values, or + resolving unexpected hangs. Usage: + ```python + from vllm_ascend.utils import acl_graph_print + ... + acl_graph_print("Debug info") + ``` + + This function launches a host function on the current compute stream to print + the given arguments. It uses a dedicated stream for printing to avoid + interfering with computation. + + NOTE: torch.compile does not support this function, only use this in non-compiled code. + For example, those custom ops like `unified_attention_with_output` or `moe_forward`. + """ + global _SUBSCRIBED_COMPUTE_STREAMS + global _GRAPH_PRINT_STREAM + + current_compute_stream = torch_npu.npu.current_stream() + + with _GRAPH_PRINT_STREAM_LOCK: + if _GRAPH_PRINT_STREAM is None: + _GRAPH_PRINT_STREAM = torch_npu.npu.Stream() + + if current_compute_stream not in _SUBSCRIBED_COMPUTE_STREAMS: + # Subscribe the compute stream to allow launching host functions. + torch_npu.npu._subscribe_report(current_compute_stream) + _SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream) + + torch_npu.npu._launch_host_func(current_compute_stream, + _print_callback_on_stream, args) + + +def _unregister_print_streams_on_exit(): + """Unsubscribe all compute streams used for printing at exit.""" + global _SUBSCRIBED_COMPUTE_STREAMS + with _GRAPH_PRINT_STREAM_LOCK: + for stream in _SUBSCRIBED_COMPUTE_STREAMS: + torch_npu.npu._unsubscribe_report(stream) + + +atexit.register(_unregister_print_streams_on_exit) def is_310p():