93 lines
2.5 KiB
Python
93 lines
2.5 KiB
Python
import logging
|
|
from abc import ABC
|
|
from contextlib import contextmanager
|
|
|
|
# torch_memory_saver is an optional dependency. Record the outcome of the
# import here so TorchMemorySaverAdapter.create() can report the failure only
# when the feature is actually requested.
try:
    import torch_memory_saver

    _memory_saver = torch_memory_saver.torch_memory_saver
    import_error = None
except ImportError as e:
    # Bind the name even on failure: _TorchMemorySaverAdapterReal.enabled
    # tests `_memory_saver is not None`, which would otherwise raise NameError
    # instead of evaluating to False.
    _memory_saver = None
    import_error = e

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TorchMemorySaverAdapter(ABC):
    """Common interface over the optional ``torch_memory_saver`` package.

    Instances are obtained through :meth:`create`. Every operation defined
    here raises ``NotImplementedError``; the concrete adapters in this module
    override them.
    """

    @staticmethod
    def create(enable: bool):
        """Build the adapter that matches ``enable``.

        Args:
            enable: Whether the memory saver was requested.

        Returns:
            A ``_TorchMemorySaverAdapterNoop`` when ``enable`` is false,
            otherwise a ``_TorchMemorySaverAdapterReal``.

        Raises:
            ImportError: ``enable`` is true but importing
                ``torch_memory_saver`` failed at module load time. The
                recorded error is re-raised after a warning is logged.
        """
        if not enable:
            return _TorchMemorySaverAdapterNoop()

        if import_error is not None:
            logger.warning(
                "enable_memory_saver is enabled, but "
                "torch-memory-saver is not installed. Please install it "
                "via `pip3 install torch-memory-saver`. "
            )
            raise import_error

        return _TorchMemorySaverAdapterReal()

    def check_validity(self, caller_name):
        """Log a warning if ``caller_name`` is used while the saver is disabled.

        Args:
            caller_name: Name interpolated into the warning message.
        """
        if self.enabled:
            return

        logger.warning(
            f"`{caller_name}` will not save memory because torch_memory_saver is not enabled. "
            f"Potential causes: `enable_memory_saver` is false, or torch_memory_saver has installation issues."
        )

    def configure_subprocess(self):
        """Subprocess configuration hook; must be overridden by subclasses."""
        raise NotImplementedError

    def region(self, tag: str, enable_cpu_backup: bool = False):
        """Region hook identified by ``tag``; must be overridden by subclasses."""
        raise NotImplementedError

    def pause(self, tag: str):
        """Pause hook for ``tag``; must be overridden by subclasses."""
        raise NotImplementedError

    def resume(self, tag: str):
        """Resume hook for ``tag``; must be overridden by subclasses."""
        raise NotImplementedError

    @property
    def enabled(self):
        """Whether the adapter is active; must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
    """Adapter that forwards each call to the ``torch_memory_saver`` package."""

    def configure_subprocess(self):
        """Return the result of ``torch_memory_saver.configure_subprocess()``."""
        subprocess_config = torch_memory_saver.configure_subprocess()
        return subprocess_config

    def region(self, tag: str, enable_cpu_backup: bool = False):
        """Return ``_memory_saver.region(...)`` for ``tag``.

        Args:
            tag: Forwarded as the ``tag`` keyword.
            enable_cpu_backup: Forwarded as the ``enable_cpu_backup`` keyword.
        """
        region_kwargs = {"tag": tag, "enable_cpu_backup": enable_cpu_backup}
        return _memory_saver.region(**region_kwargs)

    def pause(self, tag: str):
        """Return ``_memory_saver.pause(tag=tag)``."""
        outcome = _memory_saver.pause(tag=tag)
        return outcome

    def resume(self, tag: str):
        """Return ``_memory_saver.resume(tag=tag)``."""
        outcome = _memory_saver.resume(tag=tag)
        return outcome

    @property
    def enabled(self):
        """``False`` when no saver object exists, else the saver's ``enabled``."""
        if _memory_saver is None:
            return False
        return _memory_saver.enabled
|
|
|
|
|
|
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
    """Stand-in adapter whose operations do nothing."""

    @contextmanager
    def configure_subprocess(self):
        """Context manager that yields once without doing any work."""
        yield

    @contextmanager
    def region(self, tag: str, enable_cpu_backup: bool = False):
        """Context manager that yields once; both arguments are ignored."""
        yield

    def pause(self, tag: str):
        """Do nothing; ``tag`` is ignored."""

    def resume(self, tag: str):
        """Do nothing; ``tag`` is ignored."""

    @property
    def enabled(self):
        """Always ``False`` for the no-op adapter."""
        return False
|