Files
2026-04-02 04:55:00 +00:00

107 lines
3.1 KiB
Python

# Device information: vacc counterparts of the `torch.cuda` device APIs,
# i.e. `torch.cuda.func` is replaced with `torch_vacc.vacc.func`.
# See https://pytorch.org/docs/stable/cuda.html
from typing import Any
import warnings
import torch
import torch_vacc
from torch._utils import _get_device_index
from torch_vacc._vacc_libs import _torch_vacc
from .lazy_initialize import _lazy_init
# Select a vacc device and report which device was selected before.
# Use the native binding when this build of _torch_vacc exports one;
# otherwise fall back to an emulation built on current_device()/set_device().
if hasattr(_torch_vacc, "_exchange_device"):
    _exchange_device = _torch_vacc._exchange_device
else:
    def _exchange_device(device: int) -> int:
        """Make ``device`` the current vacc device and return the previous one.

        Args:
            device: index of the vacc device to select. A negative index is a
                no-op.

        Returns:
            The index that was current before the call, or ``-1`` when
            ``device`` is negative (nothing is queried or changed).
        """
        # Do not delegate to _torch_vacc here: this branch only exists because
        # the native _exchange_device binding is missing.
        if device < 0:
            return -1
        prev_device = current_device()
        # Skip the runtime call when the requested device is already current.
        if device != prev_device:
            set_device(device)
        return prev_device
class device:
    """Context manager that temporarily switches the selected vacc device.

    Entering the ``with`` block selects the requested device and remembers
    the previously selected index. Leaving the block re-selects that index.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        # Nothing has been swapped yet, so there is nothing to restore.
        self.prev_idx = -1
        # Resolve the target index now; the actual switch happens on entry.
        self.idx = _get_device_index(device, optional=True)

    def __enter__(self):
        previous = _exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, *exc_info):
        # Put back whichever device was selected before __enter__.
        _exchange_device(self.prev_idx)
        # Exceptions raised inside the ``with`` body are never suppressed.
        return False
def is_available() -> bool:
    r"""Report whether any vacc device can be used.

    Returns:
        ``True`` when :func:`device_count` reports at least one device.
    """
    count = device_count()
    return count > 0
def is_bf16_supported() -> bool:
    r"""Tell whether the current vacc device supports ``torch.bfloat16``.

    The answer is hard-wired to ``True``; no device query is performed.
    """
    return True
def current_device() -> int:
    r"""Return the index of the vacc device that is currently selected.

    Lazy initialization (``_lazy_init``) runs before the runtime is queried.
    """
    _lazy_init()
    index = _torch_vacc._current_device()
    return index
def set_device(device: Any) -> None:
    r"""Select the current vacc device.

    Args:
        device (torch.device, str or int): device to select. It is resolved
            to an index with ``torch._utils._get_device_index(..., optional=True)``.
            A negative index is a no-op.
    """
    device_index = _get_device_index(device, optional=True)
    # Negative indices leave the current selection untouched.
    if device_index >= 0:
        _torch_vacc._set_device(device_index)
def get_device_capability(device=None):
    r"""Query the (major, minor) capability of a vacc device -- unsupported.

    No vacc counterpart to CUDA's (major, minor) capability is exposed here,
    so this function never returns. Callers should perform version checks
    by other means.

    Args:
        device: accepted only for signature compatibility with
            ``torch.cuda.get_device_capability``; it is ignored.

    Raises:
        AssertionError: always.
    """
    raise AssertionError(
        "torch.vacc.get_device_capability isn't implemented! Please do the "
        "version check in other ways, Unlike CUDA major,min"
    )
def get_device_name(device_name=None):
    r"""Return the name of a vacc device.

    Args:
        device_name (torch.device, str, int or None): the device to query.
            Despite the parameter name, this is a device, not a name. It is
            resolved with ``torch._utils._get_device_index(..., optional=True)``.

    Returns:
        The ``name`` attribute of the object returned by
        :func:`get_device_properties`.

    Raises:
        AssertionError: if the resolved index is negative or not below
            :func:`device_count`.
    """
    # Delegate so the index validation and lookup live in one place.
    return get_device_properties(device_name).name


def get_device_properties(device_name=None):
    r"""Return the properties object the vacc runtime reports for a device.

    Args:
        device_name (torch.device, str, int or None): the device to query.
            Despite the parameter name, this is a device, not a name. It is
            resolved with ``torch._utils._get_device_index(..., optional=True)``.

    Returns:
        The result of ``_torch_vacc._vacc_getDeviceProperties`` for the
        resolved index (it exposes at least a ``name`` attribute).

    Raises:
        AssertionError: if the resolved index is negative or not below
            :func:`device_count`.
    """
    device_id = _get_device_index(device_name, optional=True)
    if device_id < 0 or device_id >= device_count():
        raise AssertionError("Invalid device id")
    # The runtime is lazily initialized only after the index passes validation.
    _lazy_init()
    return _torch_vacc._vacc_getDeviceProperties(device_id)
def device_count():
    r"""Return how many vacc devices the runtime reports."""
    n_devices = _torch_vacc._device_count()
    return n_devices
def synchronize(device=None) -> None:
    """Block until all operations in all streams on a vacc device complete.

    Args:
        device (torch.device or int, optional): device to synchronize. It is
            selected for the duration of the call through
            ``torch_vacc.vacc.device``, which also decides how ``None`` is
            resolved.
    """
    _lazy_init()
    # The context manager is reached through the package path because the
    # bare name ``device`` is shadowed by the parameter in this function.
    guard = torch_vacc.vacc.device(device)
    with guard:
        result = _torch_vacc._device_synchronize()
    return result
# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)