# Device information
#
# Each ``torch.cuda.func`` used for device management has a counterpart
# ``torch_vacc.vacc.func`` defined here.
# See https://pytorch.org/docs/stable/cuda.html
from typing import Any
import warnings

import torch
import torch_vacc
from torch._utils import _get_device_index
from torch_vacc._vacc_libs import _torch_vacc

from .lazy_initialize import _lazy_init


if hasattr(_torch_vacc, "_exchange_device"):
    # Prefer the native implementation when the extension exports one.
    _exchange_device = _torch_vacc._exchange_device
else:

    def _exchange_device(device: int) -> int:
        """Make ``device`` the current device and return the previous index.

        Pure-Python fallback, used only when ``_torch_vacc`` does not export
        ``_exchange_device``.  A negative ``device`` is a no-op that returns
        ``-1``; passing that ``-1`` back in (as ``device.__exit__`` does) is
        therefore also a no-op.
        """
        # Do not delegate to ``_torch_vacc._exchange_device`` here: this
        # branch only exists because that attribute is missing.
        if device < 0:
            return -1
        prev_device = current_device()
        if device != prev_device:
            set_device(device)
        return prev_device


class device(object):
    """Context-manager that changes the selected device.

    On entry the requested device becomes current; on exit the previously
    selected device is restored.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        # -1 means "nothing to restore" until __enter__ records a device.
        self.prev_idx = -1

    def __enter__(self):
        # Remember the previously selected device so __exit__ can restore it.
        self.prev_idx = _exchange_device(self.idx)

    def __exit__(self, *args):
        _exchange_device(self.prev_idx)
        # Never suppress exceptions raised inside the ``with`` body.
        return False


def is_available() -> bool:
    r"""Returns whether vacc is available (at least one device is visible)."""
    return device_count() > 0


def is_bf16_supported() -> bool:
    r"""Returns a bool indicating if the current vacc device supports dtype bfloat16.

    Unconditionally returns ``True``; no per-device query is performed.
    """
    return True


def current_device() -> int:
    r"""Returns the index of a currently selected vacc device."""
    _lazy_init()
    return _torch_vacc._current_device()


def set_device(device: Any) -> None:
    r"""Select the current vacc device.

    Args:
        device: a ``torch.device`` or an integer index (anything accepted by
            ``torch._utils._get_device_index``).  A negative resolved index is
            silently ignored.
    """
    device_index = _get_device_index(device, optional=True)
    if device_index >= 0:
        _torch_vacc._set_device(device_index)


def get_device_capability(device=None):
    r"""Query the (major, minor) capability version of a device.

    Not supported: the vacc backend has no corresponding concept, so callers
    must perform version checks another way.

    Raises:
        AssertionError: always.
    """
    _infos = (
        "torch.vacc.get_device_capability isn't implemented! "
        "Please do the version check in other ways, Unlike CUDA major,min"
    )
    raise AssertionError(_infos)


def _checked_device_index(device_name: Any) -> int:
    """Resolve ``device_name`` to an index and verify it names a real device.

    Raises:
        AssertionError: if the resolved index is negative or not below
            ``device_count()``.
    """
    device_id = _get_device_index(device_name, optional=True)
    if device_id < 0 or device_id >= device_count():
        raise AssertionError("Invalid device id")
    return device_id


def get_device_name(device_name=None):
    r"""Return the name reported in the properties of a vacc device.

    Args:
        device_name: device to query; resolved with
            ``_get_device_index(..., optional=True)``.

    Raises:
        AssertionError: if the device index is out of range.
    """
    return get_device_properties(device_name).name


def get_device_properties(device_name=None):
    r"""Return the properties object of a vacc device.

    Args:
        device_name: device to query; resolved with
            ``_get_device_index(..., optional=True)``.

    Raises:
        AssertionError: if the device index is out of range.
    """
    device_id = _checked_device_index(device_name)
    _lazy_init()
    return _torch_vacc._vacc_getDeviceProperties(device_id)


def device_count():
    r"""Returns the number of available vacc devices"""
    return _torch_vacc._device_count()


def synchronize(device=None) -> None:
    """Waits for all operations in all streams on a VACC device to complete."""
    _lazy_init()
    # The ``device`` parameter shadows this module's ``device`` class, so the
    # context manager is reached through its fully qualified name.
    with torch_vacc.vacc.device(device):
        return _torch_vacc._device_synchronize()


# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)