Files
enginex-vastai-va16-vllm/torch_vacc/vacc/streams.py
2026-04-02 04:55:00 +00:00

328 lines
10 KiB
Python

import ctypes
from typing import Any, Optional
import torch
from packaging import version
from torch._utils import _get_device_index
try:
from torch._streambase import _StreamBase, _EventBase
except ImportError:
# torch <= 2.1
_StreamBase = _EventBase = object
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
from ._device import device
from .lazy_initialize import _lazy_init
# Drop any local-version suffix (e.g. "+cpu") so that the version checks
# below, which go through packaging's version.parse, see only the public
# version string.
torch_version = torch.__version__.partition("+")[0]
class _StreamCommon:
    """Python-side implementation shared by every ``Stream`` variant below.

    A VACC stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    This mixin is listed FIRST among ``Stream``'s bases. Only with that
    ordering do its methods take precedence over same-named attributes of
    ``_torch_vacc._VACCStreamBase`` (including its ``__new__``) and over the
    abstract stubs of ``torch._streambase._StreamBase``. The same ordering
    makes ``super()`` inside these methods resolve to the native
    ``_VACCStreamBase`` implementation rather than to ``object``.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream. Can be either
            -1 (high priority) or 0 (low priority). By default, streams have
            priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Only switch the current device when a device was requested and the
        # caller has not already identified an existing stream by
        # stream_id/device_index (as current_stream()/default_stream() do).
        # Zero-arg super() resolves to the native base because this mixin
        # precedes it in Stream's MRO; super(Stream, cls) would recurse here.
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        with torch_vacc.vacc.device(device):
            return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        """Make all future work submitted to this stream wait for ``event``.

        Args:
            event (Event): event to wait for; its ``wait`` is called with this
                stream.
        """
        event.wait(self)

    def record_event(self, event=None):
        """Records an event.

        Args:
            event (torch_vacc.Event, optional): event to record. If not given,
                a new one will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def wait_stream(self, stream):
        """Synchronizes with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.
        """
        self.wait_event(stream.record_event())

    def query(self):
        """Delegate to the native ``_VACCStreamBase.query`` and return its result."""
        return super().query()

    def synchronize(self):
        """Delegate to the native ``_VACCStreamBase.synchronize``."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol: lets a Stream be passed straight to a foreign
        # function as a void pointer holding the raw vacc stream handle.
        return ctypes.c_void_p(self.vacc_stream)

    def __eq__(self, o):
        # Only Streams can compare equal to Streams; the actual comparison is
        # performed by the native base class.
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        return hash((self.vacc_stream, self.device))

    def __repr__(self):
        return f"torch_vacc.vacc.Stream device={self.device} vacc_stream={self.vacc_stream:#x}"


# ``_StreamCommon`` must come first in the bases (see its docstring).
# ``_StreamBase`` is only added when torch._streambase was actually importable
# (it falls back to ``object`` on torch <= 2.1, and listing ``object`` before
# another base makes the MRO inconsistent) and torch is older than 2.6.
# Checking the fallback directly also covers torch 2.1.1/2.1.2, which a
# ``<= "2.1"`` version comparison misclassifies.
if _StreamBase is object or version.parse(torch_version) >= version.parse("2.6"):
    # torch <= 2.1, or torch >= 2.6
    class Stream(_StreamCommon, _torch_vacc._VACCStreamBase):
        pass
else:
    # 2.2 <= torch < 2.6
    class Stream(_StreamCommon, _torch_vacc._VACCStreamBase, _StreamBase):
        pass
class _EventCommon:
    """Python-side implementation shared by every ``Event`` variant below.

    VACC events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize VACC
    streams.

    The underlying VACC events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on the
    same device may record the event. However, streams on any device can wait on
    the event.

    This mixin is listed FIRST among ``Event``'s bases. Only with that ordering
    are its methods actually used:
      * ``__new__`` translates ``enable_timing`` into the native ``calc_time``
        keyword;
      * ``record``/``wait`` fall back to the current stream;
    and ``super()`` inside them then reaches the native ``_VACCEventBase``.

    Args:
        enable_timing (bool, optional): indicates if the event should measure
            time (default: ``False``). Forwarded to the native base as
            ``calc_time``.
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False, blocking=False):
        # The native base names the timing flag ``calc_time``. Zero-arg super()
        # resolves to the native base because this mixin precedes it in
        # Event's MRO; super(Event, cls) would recurse here.
        return super().__new__(
            cls,
            calc_time=enable_timing,
            blocking=blocking,
        )

    def record(self, stream=None):
        """Records the event in a given stream.

        Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        if stream is None:
            stream = torch_vacc.vacc.current_stream()
        super().record(stream)

    def wait(self, stream=None):
        """Makes all future work submitted to the given stream wait for this
        event.

        Use ``torch_vacc.vacc.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``vaccrtStreamWaitEvent()``
        """
        if stream is None:
            stream = torch_vacc.vacc.current_stream()
        super().wait(stream)

    def query(self):
        """Checks if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        """Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self):
        r"""Waits for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        .. note:: This is a wrapper around ``vaccEventSynchronize()``.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol: lets an Event be passed straight to a foreign
        # function as a void pointer holding the raw vacc event handle.
        return ctypes.c_void_p(self.vacc_event)

    def __repr__(self):
        # A falsy handle means the native event has not been created yet.
        if self.vacc_event:
            return f"<torch_vacc.vacc.Event {self._as_parameter_.value:#x}>"
        else:
            return "<torch_vacc.vacc.Event uninitialized>"


# ``_EventCommon`` must come first in the bases (see its docstring).
# ``_EventBase`` is only added when torch._streambase was actually importable
# (it falls back to ``object`` on torch <= 2.1, and listing ``object`` before
# another base makes the MRO inconsistent) and torch is older than 2.6.
# Checking the fallback directly also covers torch 2.1.1/2.1.2, which a
# ``<= "2.1"`` version comparison misclassifies.
if _EventBase is object or version.parse(torch_version) >= version.parse("2.6"):
    # torch <= 2.1, or torch >= 2.6
    class Event(_EventCommon, _torch_vacc._VACCEventBase):
        pass
else:
    # 2.2 <= torch < 2.6
    class Event(_EventCommon, _torch_vacc._VACCEventBase, _EventBase):
        pass
class StreamContext:
    r"""Context manager that makes ``stream`` the current VACC stream.

    Every VACC kernel queued inside the ``with`` body is enqueued on the
    selected stream. On exit, the stream that was current before entry is
    restored. If the selected stream lives on a different device than the
    current one, that device's previous stream is restored as well.

    Args:
        stream (stream): the stream to select. Passing ``None`` turns the
            manager into a no-op.

    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch_vacc.vacc.Stream"]

    def __init__(self, stream: Optional["torch_vacc.vacc.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        scripting = torch.jit.is_scripting()
        # In eager mode a missing device index is normalised to -1, which
        # makes __enter__/__exit__ do nothing.
        if not scripting and self.idx is None:
            self.idx = -1
        # Streams to restore in __exit__ (filled by __enter__). Under
        # TorchScript they start out as the default stream instead of None.
        if scripting:
            self.src_prev_stream = torch_vacc.vacc.default_stream(None)
            self.dst_prev_stream = torch_vacc.vacc.default_stream(None)
        else:
            self.src_prev_stream = None
            self.dst_prev_stream = None

    def __enter__(self):
        selected = self.stream
        # Act only when a stream was given and a VACC device is available.
        if selected is not None and self.idx != -1:
            previous = torch_vacc.vacc.current_stream(None)
            self.src_prev_stream = previous
            # A stream on another device: remember that device's current
            # stream too, so __exit__ can put it back.
            if previous.device != selected.device:
                with device(selected.device):
                    self.dst_prev_stream = torch_vacc.vacc.current_stream(selected.device)
            torch_vacc.vacc.set_stream(selected)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        selected = self.stream
        if selected is None or self.idx == -1:
            return
        crossed_devices = self.src_prev_stream.device != selected.device  # type: ignore[union-attr]
        # Restore the destination device first, then the original device.
        if crossed_devices:
            torch_vacc.vacc.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch_vacc.vacc.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
def stream(stream: Optional["torch_vacc.vacc.Stream"]) -> StreamContext:
    r"""Build a :class:`StreamContext` that selects the given stream.

    Convenience wrapper so callers can write ``with stream(s): ...``.

    Arguments:
        stream (Stream): the stream to select. The returned manager is a
            no-op if this is ``None``.
    """
    manager = StreamContext(stream)
    return manager
def set_stream(stream: Optional["Stream"]):
    r"""Sets the current stream.

    This is a thin wrapper API to set the stream. Usage of this function is
    discouraged in favor of the ``stream`` context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    # The native setter takes the stream's identifying triple rather than the
    # Python Stream object itself.
    _torch_vacc._vacc_setStream(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )
def current_stream(device=None) -> Stream:
    r"""Return the :class:`Stream` that is currently selected on a device.

    Args:
        device (torch.device or int, optional): device to query. When ``None``
            (default), the current device as given by
            :func:`~torch_vacc.vacc.current_device` is used.
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    # The native call returns (stream_id, device_index, device_type).
    packed = _torch_vacc._vacc_getCurrentStream(index)
    return Stream(
        stream_id=packed[0],
        device_index=packed[1],
        device_type=packed[2],
    )
def default_stream(device=None) -> Stream:
    r"""Return the default :class:`Stream` of a device.

    Args:
        device (torch.device or int, optional): device to query. When ``None``
            (default), the current device as given by
            :func:`~torch_vacc.vacc.current_device` is used.
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    # The native call returns (stream_id, device_index, device_type).
    packed = _torch_vacc._vacc_getDefaultStream(index)
    return Stream(
        stream_id=packed[0],
        device_index=packed[1],
        device_type=packed[2],
    )