# Device information
|
|
# replacing ``torch.cuda.func`` with ``torch_vacc.vacc.func``.
|
|
# see https://pytorch.org/docs/stable/cuda.html
|
|
|
|
from typing import Any
|
|
import warnings
|
|
|
|
import torch
|
|
import torch_vacc
|
|
from torch._utils import _get_device_index
|
|
from torch_vacc._vacc_libs import _torch_vacc
|
|
|
|
from .lazy_initialize import _lazy_init
|
|
|
|
if hasattr(_torch_vacc, "_exchange_device"):
    # Fast path: the extension library provides a native implementation.
    _exchange_device = _torch_vacc._exchange_device
else:

    def _exchange_device(device: int) -> int:
        """Switch the current vacc device and return the previous index.

        Pure-Python fallback used when the extension library does not
        export ``_exchange_device``.

        Args:
            device: target device index. A negative value is a no-op.

        Returns:
            The previously selected device index, or ``-1`` when ``device``
            is negative (meaning "nothing was changed").
        """
        # BUG FIX: the original body started with an unconditional
        # ``return _torch_vacc._exchange_device()`` — dead code below it,
        # and an AttributeError here since this branch only runs when the
        # extension lacks that attribute.
        if device < 0:
            return -1
        prev_device = current_device()
        if device != prev_device:
            set_device(device)
        return prev_device
|
|
|
|
|
|
class device(object):
    """Context-manager that switches the active vacc device.

    Entering the context makes the requested device current; leaving it
    restores whichever device was current before.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        # Resolve to a plain integer index up front; ``optional=True`` lets
        # ``None`` mean "keep the current device".
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        # ``_exchange_device`` returns whichever index was current before
        # the switch, so ``__exit__`` knows what to restore.
        self.prev_idx = _exchange_device(self.idx)

    def __exit__(self, *args):
        # Restore the original device; returning False never suppresses
        # an exception raised inside the block.
        _exchange_device(self.prev_idx)
        return False
|
|
|
|
|
|
def is_available() -> bool:
    r"""Return True when at least one vacc device can be used."""
    count = device_count()
    return count > 0
|
|
|
|
def is_bf16_supported() -> bool:
    r"""Return whether the current vacc device supports dtype bfloat16.

    Always True: no per-device capability query is needed here.
    """
    return True
|
|
|
|
def current_device() -> int:
    r"""Return the index of the currently selected vacc device."""
    # Make sure the runtime is up before querying the extension library.
    _lazy_init()
    return _torch_vacc._current_device()
|
|
|
|
|
|
def set_device(device: torch.device):
    r"""Make ``device`` the current vacc device.

    Args:
        device (torch.device or int): device to select. Negative indices
            resolve to a no-op.
    """
    idx = _get_device_index(device, optional=True)
    # Guard clause: a negative resolved index means "do nothing".
    if idx < 0:
        return
    _torch_vacc._set_device(idx)
|
|
|
|
|
|
def get_device_capability(device=None):
    r"""Unsupported: vacc has no equivalent of CUDA's ``(major, minor)``
    compute-capability pair, so this function always raises.

    Args:
        device: accepted only for API compatibility with ``torch.cuda``;
            ignored.

    Raises:
        AssertionError: always, with a hint to perform version checks
            another way.
    """
    # NOTE: the original docstring claimed this "returns None by default",
    # which contradicted the unconditional raise below; the message text is
    # kept byte-identical so callers matching on it are unaffected.
    _infos = "torch.vacc.get_device_capability isn't implemented! Please do the version check in other ways, Unlike CUDA major,min"
    raise AssertionError(_infos)
|
|
|
|
|
|
def get_device_name(device_name=None):
    r"""Return the name of the given vacc device.

    Args:
        device_name (torch.device or int, optional): device to query;
            ``None`` resolves to the current device.

    Raises:
        AssertionError: if the resolved index is out of range.
    """
    idx = _get_device_index(device_name, optional=True)
    if not (0 <= idx < device_count()):
        raise AssertionError("Invalid device id")
    _lazy_init()
    return _torch_vacc._vacc_getDeviceProperties(idx).name
|
|
|
|
|
|
def get_device_properties(device_name=None):
    r"""Return the properties object of the given vacc device.

    Args:
        device_name (torch.device or int, optional): device to query;
            ``None`` resolves to the current device.

    Raises:
        AssertionError: if the resolved index is out of range.
    """
    idx = _get_device_index(device_name, optional=True)
    if idx < 0 or idx >= device_count():
        raise AssertionError("Invalid device id")
    _lazy_init()
    props = _torch_vacc._vacc_getDeviceProperties(idx)
    return props
|
|
|
|
|
|
def device_count() -> int:
    r"""Return the number of vacc devices visible to this process."""
    count = _torch_vacc._device_count()
    return count
|
|
|
|
|
|
def synchronize(device=None) -> None:
    """Block until every stream on the given VACC device has finished.

    Args:
        device (torch.device or int, optional): device to synchronize;
            ``None`` means the current device.
    """
    _lazy_init()
    # Temporarily select the target device while issuing the synchronize.
    with torch_vacc.vacc.device(device):
        return _torch_vacc._device_synchronize()
|
|
|
|
|
|
# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)
|