init
This commit is contained in:
106
torch_vacc/vacc/_device.py
Normal file
106
torch_vacc/vacc/_device.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Device information
|
||||
# replacing `torch.cuda.func`` with `torch_vacc.vacc.func`.
|
||||
# see https://pytorch.org/docs/stable/cuda.html
|
||||
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch_vacc
|
||||
from torch._utils import _get_device_index
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
|
||||
from .lazy_initialize import _lazy_init
|
||||
|
||||
if hasattr(_torch_vacc, "_exchange_device"):
|
||||
_exchange_device = _torch_vacc._exchange_device
|
||||
else:
|
||||
|
||||
def _exchange_device(device: int) -> int:
|
||||
return _torch_vacc._exchange_device()
|
||||
if device < 0:
|
||||
return -1
|
||||
prev_device = current_device()
|
||||
if device != prev_device:
|
||||
set_device(device)
|
||||
return prev_device
|
||||
|
||||
|
||||
class device(object):
|
||||
"""Context-manager that changes the selected device.
|
||||
|
||||
Args:
|
||||
device (torch.device or int): device index to select. It's a no-op if
|
||||
this argument is a negative integer or ``None``.
|
||||
"""
|
||||
|
||||
def __init__(self, device: Any):
|
||||
self.idx = _get_device_index(device, optional=True)
|
||||
self.prev_idx = -1
|
||||
|
||||
def __enter__(self):
|
||||
self.prev_idx = _exchange_device(self.idx)
|
||||
|
||||
def __exit__(self, *args):
|
||||
_exchange_device(self.prev_idx)
|
||||
return False
|
||||
|
||||
|
||||
def is_available() -> bool:
|
||||
r"""Returns whether vacc is available."""
|
||||
return device_count() > 0
|
||||
|
||||
def is_bf16_supported() -> bool:
|
||||
r"""Returns a bool indicating if the current vacc device supports dtype bfloat16"""
|
||||
return True
|
||||
|
||||
def current_device() -> int:
|
||||
r"""Returns the index of a currently selected vacc device."""
|
||||
_lazy_init()
|
||||
return _torch_vacc._current_device()
|
||||
|
||||
|
||||
def set_device(device: torch.device):
|
||||
device_index = _get_device_index(device, optional=True)
|
||||
if device_index >= 0:
|
||||
_torch_vacc._set_device(device_index)
|
||||
|
||||
|
||||
def get_device_capability(device=None):
|
||||
r"""Query the minor and major data of device. Cann does not
|
||||
have a corresponding concept and is not supported. By default, it returns None
|
||||
"""
|
||||
_infos = "torch.vacc.get_device_capability isn't implemented! Please do the version check in other ways, Unlike CUDA major,min"
|
||||
raise AssertionError(_infos)
|
||||
|
||||
|
||||
def get_device_name(device_name=None):
|
||||
device_id = _get_device_index(device_name, optional=True)
|
||||
if device_id < 0 or device_id >= device_count():
|
||||
raise AssertionError("Invalid device id")
|
||||
_lazy_init()
|
||||
device_prop = _torch_vacc._vacc_getDeviceProperties(device_id)
|
||||
return device_prop.name
|
||||
|
||||
|
||||
def get_device_properties(device_name=None):
|
||||
device_id = _get_device_index(device_name, optional=True)
|
||||
if device_id < 0 or device_id >= device_count():
|
||||
raise AssertionError("Invalid device id")
|
||||
_lazy_init()
|
||||
return _torch_vacc._vacc_getDeviceProperties(device_id)
|
||||
|
||||
|
||||
def device_count():
|
||||
r"""Returns the number of available vacc devices"""
|
||||
return _torch_vacc._device_count()
|
||||
|
||||
|
||||
def synchronize(device=None) -> None:
|
||||
"""Waits for all operations in all streams on a VACC device to complete."""
|
||||
_lazy_init()
|
||||
with torch_vacc.vacc.device(device):
|
||||
return _torch_vacc._device_synchronize()
|
||||
|
||||
|
||||
# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)
|
||||
Reference in New Issue
Block a user