Add device detection and count functions to utils. (#3962)
This commit is contained in:
committed by
GitHub
parent
959a3143fc
commit
76f6c0ebf9
@@ -37,6 +37,7 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
from importlib.util import find_spec
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from multiprocessing.reduction import ForkingPickler
|
from multiprocessing.reduction import ForkingPickler
|
||||||
@@ -1051,6 +1052,65 @@ def get_device_name(device_id: int = 0) -> str:
|
|||||||
return torch.hpu.get_device_name(device_id)
|
return torch.hpu.get_device_name(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def is_habana_available() -> bool:
    """Report whether the ``habana_frameworks`` package can be located.

    Only a module-spec lookup is performed; the package is not imported here.
    The answer is memoized by ``lru_cache``, so the lookup runs once.

    Returns:
        True if a module spec for ``habana_frameworks`` is found, else False.
    """
    habana_spec = find_spec("habana_frameworks")
    return habana_spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=8)
def get_device(device_id: Optional[int] = None) -> str:
    """Return the device string of the first available accelerator backend.

    Backends are probed in the order CUDA, XPU, HPU. Results are memoized per
    ``device_id`` (up to 8 distinct values) by ``lru_cache``.

    Args:
        device_id: Optional device index. When ``None`` the bare backend name
            is returned (e.g. ``"cuda"``); otherwise the index is appended
            (e.g. ``"cuda:1"``). The index is not validated against the
            number of devices actually present.

    Returns:
        A device string such as ``"cuda"``, ``"xpu:0"`` or ``"hpu:2"``.

    Raises:
        ImportError: If the ``habana_frameworks`` package is found but
            ``habana_frameworks.torch.hpu`` cannot be imported.
        RuntimeError: If none of CUDA, XPU or HPU is available.
    """
    # Identity check against None so that index 0 still yields an explicit ":0".
    suffix = "" if device_id is None else f":{device_id}"

    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return f"cuda{suffix}"

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return f"xpu{suffix}"

    if is_habana_available():
        try:
            # Imported for its side effect; the code below then reads `torch.hpu`.
            import habana_frameworks.torch.hpu  # noqa: F401
        except ImportError as e:
            # Chain the original error so the root cause stays in the traceback.
            raise ImportError(
                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
            ) from e

        if torch.hpu.is_available():
            return f"hpu{suffix}"

    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def get_device_count() -> int:
    """Return the device count reported by the first available backend.

    Backends are probed in the order CUDA, XPU, HPU. The first backend that
    reports itself available decides the result; later backends are not
    consulted. The value is memoized by ``lru_cache``.

    Returns:
        The backend's ``device_count()``, or 0 when no backend is available
        or when querying the count raises ``RuntimeError`` (for HPU, also
        when importing ``habana_frameworks.torch.hpu`` raises ``ImportError``).
    """
    # CUDA and XPU share the same probing shape, so handle them in one loop.
    for backend_name in ("cuda", "xpu"):
        if not hasattr(torch, backend_name):
            continue
        backend = getattr(torch, backend_name)
        if not backend.is_available():
            continue
        try:
            return backend.device_count()
        except RuntimeError:
            return 0

    if not is_habana_available():
        return 0  # No accelerators available

    try:
        # Imported for its side effect; the code below then reads `torch.hpu`.
        import habana_frameworks.torch.hpu  # noqa: F401

        return torch.hpu.device_count() if torch.hpu.is_available() else 0
    except (ImportError, RuntimeError):
        return 0
|
||||||
|
|
||||||
|
|
||||||
def get_device_core_count(device_id: int = 0) -> int:
|
def get_device_core_count(device_id: int = 0) -> int:
|
||||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||||
@@ -1069,11 +1129,12 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
|||||||
)
|
)
|
||||||
major, minor = int(major), int(minor)
|
major, minor = int(major), int(minor)
|
||||||
|
|
||||||
# TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
|
|
||||||
# Update this once the support is available.
|
|
||||||
if hasattr(torch, "hpu") and torch.hpu.is_available():
|
if hasattr(torch, "hpu") and torch.hpu.is_available():
|
||||||
try:
|
try:
|
||||||
major, minor = torch.hpu.get_device_capability(device_id)
|
# TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
|
||||||
|
# Update this once the support is available.
|
||||||
|
# major, minor = torch.hpu.get_device_capability(device_id)
|
||||||
|
major, minor = None, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"An error occurred while getting device capability of hpu: {e}."
|
f"An error occurred while getting device capability of hpu: {e}."
|
||||||
|
|||||||
Reference in New Issue
Block a user