104 lines
3.4 KiB
Python
104 lines
3.4 KiB
Python
import os
|
|
from importlib.util import find_spec
|
|
from typing import List, Union
|
|
|
|
import torch
|
|
|
|
import ray
|
|
import ray._private.ray_constants as ray_constants
|
|
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
|
|
from ray._private.accelerators.mlu import MLU_VISIBLE_DEVICES_ENV_VAR
|
|
|
|
|
|
def is_package_present(package_name: str) -> bool:
    """Report whether ``package_name`` can be located by the import system.

    Only the module spec is looked up via ``importlib.util.find_spec``. For a
    dotted name whose parent package cannot be found, ``find_spec`` raises
    ``ModuleNotFoundError``; that case is reported as "not present".
    """
    try:
        spec = find_spec(package_name)
    except ModuleNotFoundError:
        return False
    return spec is not None
|
|
|
|
|
|
# Whether the optional Cambricon PyTorch extension can be found; evaluated
# once at module import time and consulted by the device manager below.
MLU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_mlu")

if MLU_TORCH_PACKAGE_AVAILABLE:
    # Imported only for its side effects (presumably it installs the
    # `torch.mlu` namespace used below — confirm); the name itself is unused.
    import torch_mlu  # noqa: F401
|
|
|
|
|
|
class MLUTorchDeviceManager(TorchDeviceManager):
    """Cambricon MLU device manager.

    Implements the ``TorchDeviceManager`` interface on top of ``torch.mlu``.
    Every method that touches ``torch.mlu`` assumes the optional
    ``torch_mlu`` extension is installed (see ``MLU_TORCH_PACKAGE_AVAILABLE``).
    """

    @staticmethod
    def register_custom_torch_dist_backend():
        """Import ``torch_mlu`` (when installed) for its side effects.

        NOTE(review): presumably importing ``torch_mlu`` is what registers the
        MLU distributed backend with torch — confirm against torch_mlu docs.
        """
        if MLU_TORCH_PACKAGE_AVAILABLE:
            import torch_mlu  # noqa: F401, F811

    def is_available(self) -> bool:
        """Return True if ``torch_mlu`` is installed and torch reports an MLU."""
        if not MLU_TORCH_PACKAGE_AVAILABLE:
            return False

        return torch.mlu.is_available()

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device list configured for this process.

        Returns a list of torch MLU devices allocated for the current worker.
        Each accelerator ID assigned by Ray is translated into its position
        within ``MLU_VISIBLE_DEVICES``, and that position is used as the torch
        device index. If no accelerator IDs are assigned (e.g. on the driver
        or outside of Ray Train), a single-element list ``[mlu:0]`` is
        returned.

        Raises:
            RuntimeError: if ``torch_mlu`` is missing or torch reports no MLU,
                or if an assigned accelerator ID does not appear in
                ``MLU_VISIBLE_DEVICES``.
        """
        if not (MLU_TORCH_PACKAGE_AVAILABLE and torch.mlu.is_available()):
            raise RuntimeError(
                "Using MLUTorchDeviceManager but torch mlu is not available."
            )

        # IDs are converted to strings so they compare equal to the entries
        # parsed from the comma-separated environment variable below.
        # NOTE(review): IDs are read under the "GPU" resource key; confirm
        # that MLU workers are assigned accelerators under that key.
        accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
        mlu_ids = [
            str(accelerator_id)
            for accelerator_id in accelerator_ids[ray_constants.GPU]
        ]

        if not mlu_ids:
            # If called on the driver or outside of Ray Train, return the
            # 0th device.
            return [torch.device("mlu:0")]

        mlu_visible_str = os.environ.get(MLU_VISIBLE_DEVICES_ENV_VAR, "")
        # An unset/empty value or the literal "NoDevFiles" is treated as an
        # empty visible-device list.
        if mlu_visible_str and mlu_visible_str != "NoDevFiles":
            mlu_visible_list = mlu_visible_str.split(",")
        else:
            mlu_visible_list = []

        device_ids = []
        for mlu_id in mlu_ids:
            try:
                device_ids.append(mlu_visible_list.index(mlu_id))
            except ValueError as err:
                # list.index raises ValueError (not IndexError) when the
                # element is absent.
                raise RuntimeError(
                    "MLU_VISIBLE_DEVICES set incorrectly. "
                    f"Got {mlu_visible_str}, expected to include {mlu_id}. "
                    "Did you override the `MLU_VISIBLE_DEVICES` "
                    "environment variable?"
                ) from err

        return [torch.device(f"mlu:{device_id}") for device_id in device_ids]

    def set_device(self, device: Union[torch.device, int]):
        """Set the current MLU device via ``torch.mlu.set_device``."""
        torch.mlu.set_device(device)

    def supports_stream(self) -> bool:
        """Validate if the device type support to create a stream"""
        return True

    def create_stream(self, device):
        """Create a stream on MLU device"""
        return torch.mlu.Stream(device)

    def get_stream_context(self, stream):
        """Get a torch.stream context on MLU device"""
        return torch.mlu.stream(stream)

    def get_current_stream(self):
        """Get current stream for MLU device"""
        return torch.mlu.current_stream()
|