Files
2026-02-04 17:22:39 +08:00

142 lines
3.9 KiB
Python

import enum
import random
from typing import NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
TPU = enum.auto()
HPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
OPENVINO = enum.auto()
MLU = enum.auto()
UNSPECIFIED = enum.auto()
class DeviceCapability(NamedTuple):
major: int
minor: int
def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"
def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.
It is assumed that the minor version is always a single digit.
"""
assert 0 <= self.minor < 10
return self.major * 10 + self.minor
class Platform:
_enum: PlatformEnum
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU
def is_hpu(self) -> bool:
return self._enum == PlatformEnum.HPU
def is_xpu(self) -> bool:
return self._enum == PlatformEnum.XPU
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
def is_openvino(self) -> bool:
return self._enum == PlatformEnum.OPENVINO
def is_mlu(self) -> bool:
return self._enum == PlatformEnum.MLU
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_device_capability(
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
@staticmethod
def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]:
return None
@classmethod
def has_device_capability(
cls,
capability: Union[Tuple[int, int], int],
device_id: int = 0,
) -> bool:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False
if isinstance(capability, tuple):
return current_capability >= capability
return current_capability.to_int() >= capability
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
"""Get the total memory of a device in bytes."""
raise NotImplementedError
@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
do not support `torch.inference_mode`. In such a case, they will fall
back to `torch.no_grad` by overriding this method.
"""
return torch.inference_mode(mode=True)
@classmethod
def seed_everything(cls, seed: int) -> None:
"""
Set the seed of each random module.
`torch.manual_seed` will set seed on all devices.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED