Files
sglang/python/sglang/srt/torch_memory_saver_adapter.py

93 lines
2.5 KiB
Python

import logging
from abc import ABC
from contextlib import contextmanager
try:
    import torch_memory_saver

    _memory_saver = torch_memory_saver.torch_memory_saver
    import_error = None
except ImportError as e:
    # torch_memory_saver is optional. Keep the ImportError so
    # TorchMemorySaverAdapter.create can re-raise it only when the memory
    # saver is actually requested. Bind _memory_saver to None so that code
    # testing `_memory_saver is not None` gets False instead of a NameError.
    _memory_saver = None
    import_error = e

logger = logging.getLogger(__name__)
class TorchMemorySaverAdapter(ABC):
    """Facade over the optional ``torch_memory_saver`` package.

    Obtain a concrete adapter through :meth:`create`. When the memory saver
    is enabled, it returns one that forwards to ``torch_memory_saver``;
    otherwise it returns a no-op one. Every operation below must be
    overridden by a subclass; the base versions raise
    ``NotImplementedError``.
    """

    @staticmethod
    def create(enable: bool):
        """Return the adapter that matches ``enable``.

        Raises:
            ImportError: the error captured when importing
                ``torch_memory_saver`` failed, re-raised (after logging a
                warning) only when ``enable`` is true.
        """
        if not enable:
            return _TorchMemorySaverAdapterNoop()
        if import_error is not None:
            logger.warning(
                "enable_memory_saver is enabled, but "
                "torch-memory-saver is not installed. Please install it "
                "via `pip3 install torch-memory-saver`. "
            )
            raise import_error
        return _TorchMemorySaverAdapterReal()

    def check_validity(self, caller_name):
        """Log a warning that ``caller_name`` saves no memory when this adapter is disabled."""
        if self.enabled:
            return
        logger.warning(
            f"`{caller_name}` will not save memory because torch_memory_saver is not enabled. "
            "Potential causes: `enable_memory_saver` is false, or torch_memory_saver has installation issues."
        )

    def configure_subprocess(self):
        """Subprocess-configuration hook; subclasses must implement it."""
        raise NotImplementedError

    def region(self, tag: str, enable_cpu_backup: bool = False):
        """Tagged memory-region hook; subclasses must implement it."""
        raise NotImplementedError

    def pause(self, tag: str):
        """Pause memory tagged ``tag``; subclasses must implement it."""
        raise NotImplementedError

    def resume(self, tag: str):
        """Resume memory tagged ``tag``; subclasses must implement it."""
        raise NotImplementedError

    @property
    def enabled(self):
        """Whether this adapter actually saves memory; subclasses must implement it."""
        raise NotImplementedError
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
    """Adapter for TorchMemorySaver with tag-based control.

    Most operations forward to the module-level ``_memory_saver`` object
    (``torch_memory_saver.torch_memory_saver``), passing arguments by keyword.
    """

    def configure_subprocess(self):
        """Return the package-level ``torch_memory_saver.configure_subprocess()`` result."""
        # NOTE(review): the no-op adapter implements this as a context
        # manager, so callers presumably use the returned value in a
        # `with` statement — confirm against torch_memory_saver.
        return torch_memory_saver.configure_subprocess()

    def region(self, tag: str, enable_cpu_backup: bool = False):
        """Forward to ``_memory_saver.region`` with ``tag`` and ``enable_cpu_backup``."""
        return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)

    def pause(self, tag: str):
        """Forward to ``_memory_saver.pause`` for ``tag``."""
        return _memory_saver.pause(tag=tag)

    def resume(self, tag: str):
        """Forward to ``_memory_saver.resume`` for ``tag``."""
        return _memory_saver.resume(tag=tag)

    @property
    def enabled(self):
        """True only if torch_memory_saver imported and its saver reports enabled."""
        # `import_error` is bound on both the success and the failure path of
        # the optional import. `_memory_saver` is not guaranteed to be usable
        # when the import failed. Checking `import_error` first makes this
        # return False in that case instead of raising NameError.
        return import_error is None and _memory_saver.enabled
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
    """Do-nothing adapter used when the memory saver is disabled.

    Each context manager yields once with no side effects.
    ``pause``/``resume`` return ``None``, and ``enabled`` is always ``False``.
    """

    @contextmanager
    def configure_subprocess(self):
        """Yield control unchanged; no subprocess configuration is applied."""
        yield None

    @contextmanager
    def region(self, tag: str, enable_cpu_backup: bool = False):
        """Yield control unchanged; ``tag`` and ``enable_cpu_backup`` are ignored."""
        yield None

    def pause(self, tag: str):
        """Do nothing; ``tag`` is ignored."""

    def resume(self, tag: str):
        """Do nothing; ``tag`` is ignored."""

    @property
    def enabled(self):
        """Always ``False``: this adapter never saves memory."""
        return False