import ctypes
from typing import Any, Optional

import torch
from packaging import version
from torch._utils import _get_device_index

try:
    from torch._streambase import _StreamBase, _EventBase
except ImportError:
    # torch <= 2.1
    _StreamBase = _EventBase = object

import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc

from ._device import device
from .lazy_initialize import _lazy_init

# Strip the local-version suffix from the torch version (e.g. "+cpu").
torch_version = torch.__version__.split('+')[0]

_TORCH_VERSION = version.parse(torch_version)
# torch's ``_StreamBase``/``_EventBase`` are mixed into ``Stream``/``Event``
# only for 2.1 < torch < 2.6. If the import above fell back to ``object``,
# listing ``object`` before ``_StreamCommon``/``_EventCommon`` in the bases
# would have no consistent MRO (TypeError at import), so the mixins are
# skipped in that case as well.
_USE_TORCH_STREAM_BASES = (
    version.parse("2.1") < _TORCH_VERSION < version.parse("2.6")
    and _StreamBase is not object
    and _EventBase is not object
)


class _StreamCommon:
    """Wrapper around a VACC stream.

    A VACC stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device (torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority (int, optional): priority of the stream. Can be either
            -1 (high priority) or 0 (low priority). By default, streams have
            priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # NOTE(review): this mixin is listed *after* ``_VACCStreamBase`` in the
        # bases of ``Stream``; if the extension type exposes its own
        # ``__new__``, it is found first in the MRO and this method is never
        # reached. Confirm that ``Stream(device=...)`` actually runs this code.
        #
        # Skip the device context when no device was requested, or when an
        # existing stream is being wrapped (stream_id and device_index given).
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
        with torch_vacc.vacc.device(device):
            return super(Stream, cls).__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        """Makes all future work submitted to this stream wait for an event.

        Args:
            event (Event): the event to wait for.
        """
        event.wait(self)

    def record_event(self, event=None):
        """Records an event.

        Args:
            event (Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def wait_stream(self, stream):
        """Synchronizes with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.
        """
        # Snapshot `stream`'s current work with a fresh event, then make this
        # stream wait on it.
        self.wait_event(stream.record_event())

    def query(self):
        """Returns the result of the underlying stream ``query()``.

        Presumably ``True`` once all work submitted to this stream has
        completed (delegated to the extension base class).
        """
        return super().query()

    def synchronize(self):
        """Delegates to the underlying stream ``synchronize()``.

        Presumably blocks the calling thread until this stream's work is done.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol: lets a Stream be passed directly to a ctypes
        # foreign function as the raw stream handle.
        return ctypes.c_void_p(self.vacc_stream)

    def __eq__(self, o):
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        # Defined explicitly because overriding __eq__ would otherwise make
        # instances unhashable.
        return hash((self.vacc_stream, self.device))

    def __repr__(self):
        return f"torch_vacc.vacc.Stream device={self.device} vacc_stream={self.vacc_stream:#x}"


if _USE_TORCH_STREAM_BASES:
    # 2.1 < torch < 2.6 with torch._streambase available
    class Stream(_torch_vacc._VACCStreamBase, _StreamBase, _StreamCommon):
        pass
else:
    # torch <= 2.1, torch >= 2.6, or torch._streambase unavailable
    class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
        pass


class _EventCommon:
    """Wrapper around a VACC event.

    VACC events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize VACC
    streams.

    The underlying VACC events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on
    the same device may record the event. However, streams on any device can
    wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure
            time (default: ``False``). Forwarded to the extension type as
            ``calc_time``.
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False, blocking=False):
        # NOTE(review): as with _StreamCommon.__new__, this mixin follows
        # ``_VACCEventBase`` in the MRO; confirm this method is actually
        # reached when constructing ``Event(...)``.
        return super(Event, cls).__new__(
            cls,
            calc_time=enable_timing,
            blocking=blocking,
        )

    def record(self, stream=None):
        """Records the event in a given stream.

        Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified.
        The stream's device must match the event's device.
        """
        if stream is None:
            stream = torch_vacc.vacc.current_stream()
        super().record(stream)

    def wait(self, stream=None):
        """Makes all future work submitted to the given stream wait for this
        event.

        Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``vaccrtStreamWaitEvent()``
        """
        if stream is None:
            stream = torch_vacc.vacc.current_stream()
        super().wait(stream)

    def query(self):
        """Checks if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        """Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self):
        r"""Waits for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        .. note:: This is a wrapper around ``vaccEventSynchronize()``.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol: lets an Event be passed directly to a ctypes
        # foreign function as the raw event handle.
        return ctypes.c_void_p(self.vacc_event)

    def __repr__(self):
        # A falsy handle means the underlying event has not been created yet
        # (events are lazily initialized on first record/export).
        if self.vacc_event:
            return f"torch_vacc.vacc.Event vacc_event={self.vacc_event:#x}"
        return "torch_vacc.vacc.Event uninitialized"


if _USE_TORCH_STREAM_BASES:
    # 2.1 < torch < 2.6 with torch._streambase available
    class Event(_torch_vacc._VACCEventBase, _EventBase, _EventCommon):
        pass
else:
    # torch <= 2.1, torch >= 2.6, or torch._streambase unavailable
    class Event(_torch_vacc._VACCEventBase, _EventCommon):
        pass


class StreamContext:
    r"""Context-manager that selects a given stream.

    All VACC kernels queued within its context will be enqueued on the
    selected stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch_vacc.vacc.Stream"]

    def __init__(self, stream: Optional["torch_vacc.vacc.Stream"]):
        self.stream = stream
        # Current device index; -1 marks "no device index available", which
        # turns __enter__/__exit__ into no-ops.
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        # Streams to restore in __exit__: the previous stream on the current
        # (source) device and on the target stream's (destination) device.
        # Only pre-filled with a real stream when TorchScript-compiling,
        # presumably so the attributes have a concrete type there.
        self.src_prev_stream = (
            None if not torch.jit.is_scripting() else torch_vacc.vacc.default_stream(None)
        )
        self.dst_prev_stream = (
            None if not torch.jit.is_scripting() else torch_vacc.vacc.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or VACC device not available
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch_vacc.vacc.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch_vacc.vacc.current_stream(cur_stream.device)
        torch_vacc.vacc.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no VACC device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch_vacc.vacc.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch_vacc.vacc.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def stream(stream: Optional["torch_vacc.vacc.Stream"]) -> StreamContext:
    r"""Wrapper around the Context-manager StreamContext that selects a given
    stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    """
    return StreamContext(stream)


def set_stream(stream: Optional[Stream]):
    r"""Sets the current stream. This is a wrapper API to set the stream.

    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    _torch_vacc._vacc_setStream(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )


def current_stream(device=None) -> Stream:
    r"""Returns the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    # streamdata is (stream_id, device_index, device_type), as unpacked below.
    streamdata = _torch_vacc._vacc_getCurrentStream(
        _get_device_index(device, optional=True)
    )
    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )


def default_stream(device=None) -> Stream:
    r"""Returns the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    # streamdata is (stream_id, device_index, device_type), as unpacked below.
    streamdata = _torch_vacc._vacc_getDefaultStream(
        _get_device_index(device, optional=True)
    )
    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )