[Model] Support DeepSeek-V4
This commit is contained in:
114
tools/ray_mlu/device_manager/__init__.py
Normal file
114
tools/ray_mlu/device_manager/__init__.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
import ray._private.ray_constants as ray_constants
|
||||
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.mlu import MLUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
|
||||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
|
||||
|
||||
# Module-level logger for this device-manager package.
logger = logging.getLogger(__name__)


# Fallback manager class used when no supported accelerator resource is
# assigned to the current Ray worker.
DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager

'''
=============================
Modify by vllm_mlu
=============================
@brief: use MLUTorchDeviceManager when key="GPU"
'''
# NOTE(review): the "GPU" key deliberately maps to MLUTorchDeviceManager
# (not CUDATorchDeviceManager) so that Ray workers scheduled on "GPU"
# resources receive Cambricon MLU devices. Upstream Ray maps GPU -> CUDA.
SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
    ray_constants.GPU: MLUTorchDeviceManager,
    ray_constants.HPU: HPUTorchDeviceManager,
    ray_constants.NPU: NPUTorchDeviceManager,
}
'''
==================
End of MLU Hijack
==================
'''
|
||||
|
||||
|
||||
def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
    """Register vendor-specific torch.distributed backends when needed.

    Args:
        backend: Name of the requested communication backend. Only "hccl"
            triggers any registration; all other values are a no-op.
    """
    if backend != "hccl":
        return
    # The name for the communication backend of Habana and torch-npu is the same,
    # so both device managers register their backend for "hccl".
    HPUTorchDeviceManager.register_custom_torch_dist_backend()
    NPUTorchDeviceManager.register_custom_torch_dist_backend()
|
||||
|
||||
|
||||
# Process-wide singleton TorchDeviceManager, created lazily by
# get_torch_device_manager_by_context() under the lock below.
_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_torch_device_manager_by_context() -> TorchDeviceManager:
    """Return the process-wide device manager, creating it on first use.

    The manager class is chosen from the accelerator resources assigned to
    the current Ray runtime context; if none of them is supported, the
    CPU fallback is used.

    Returns:
        The lazily-created singleton TorchDeviceManager instance.

    Raises:
        RuntimeError: If more than one supported accelerator type is
            present in the runtime context's resources.
    """
    global _torch_device_manager

    with _torch_device_manager_lock:
        if not _torch_device_manager:
            resources = ray.get_runtime_context().get_accelerator_ids()

            # Pick the single supported accelerator type from the resources.
            chosen_cls = None
            for acc_type, acc_ids in resources.items():
                candidate_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
                    acc_type, None
                )
                if not acc_ids or candidate_cls is None:
                    continue
                # An error will raise when multiple accelerators are specified.
                if chosen_cls:
                    raise RuntimeError(
                        "Unable to determine the appropriate DeviceManager "
                        f"for the specified resources {resources}."
                    )
                chosen_cls = candidate_cls

            _torch_device_manager = (
                chosen_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
            )()

        return _torch_device_manager
|
||||
|
||||
|
||||
def get_torch_device_manager_by_device_type(device_type: str):
    """Return a fresh TorchDeviceManager for an explicit device-type string.

    Args:
        device_type: Device name such as "GPU", "cuda", "NPU", "HPU" or
            "CPU" (matched case-insensitively, except the literal "cuda").

    Raises:
        RuntimeError: If the device type is not recognized.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use MLUTorchDeviceManager when key="GPU"
    '''
    normalized = device_type.lower()
    # NOTE(review): "cuda" is compared case-sensitively on purpose, matching
    # the original implementation; "GPU"/"cuda" both yield the MLU manager.
    if normalized == ray_constants.GPU.lower() or device_type == "cuda":
        return MLUTorchDeviceManager()
    if normalized == ray_constants.NPU.lower():
        return NPUTorchDeviceManager()
    if normalized == ray_constants.HPU.lower():
        return HPUTorchDeviceManager()
    if normalized == "cpu":
        return CPUTorchDeviceManager()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    raise RuntimeError(f"Device type {device_type} cannot be recognized.")
|
||||
|
||||
|
||||
__all__ = [
|
||||
TorchDeviceManager,
|
||||
CPUTorchDeviceManager,
|
||||
CUDATorchDeviceManager,
|
||||
HPUTorchDeviceManager,
|
||||
NPUTorchDeviceManager,
|
||||
MLUTorchDeviceManager,
|
||||
register_custom_torch_dist_backend,
|
||||
get_torch_device_manager_by_context,
|
||||
get_torch_device_manager_by_device_type,
|
||||
]
|
||||
Reference in New Issue
Block a user