Fix cache modules of triton import error (#7832)
This commit is contained in:
@@ -83,12 +83,7 @@ from torch.func import functional_call
|
|||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
from torch.profiler import ProfilerActivity, profile, record_function
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
from torch.utils._contextlib import _DecoratorContextManager
|
from torch.utils._contextlib import _DecoratorContextManager
|
||||||
from triton.runtime.cache import (
|
from triton.runtime.cache import FileCacheManager
|
||||||
FileCacheManager,
|
|
||||||
default_cache_dir,
|
|
||||||
default_dump_dir,
|
|
||||||
default_override_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -923,18 +918,41 @@ class CustomCacheManager(FileCacheManager):
|
|||||||
|
|
||||||
self.key = key
|
self.key = key
|
||||||
self.lock_path = None
|
self.lock_path = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
module_path = "triton.runtime.cache"
|
||||||
|
cache_module = importlib.import_module(module_path)
|
||||||
|
|
||||||
|
default_cache_dir = getattr(cache_module, "default_cache_dir", None)
|
||||||
|
default_dump_dir = getattr(cache_module, "default_dump_dir", None)
|
||||||
|
default_override_dir = getattr(cache_module, "default_override_dir", None)
|
||||||
|
except (ModuleNotFoundError, AttributeError) as e:
|
||||||
|
default_cache_dir = None
|
||||||
|
default_dump_dir = None
|
||||||
|
default_override_dir = None
|
||||||
|
|
||||||
if dump:
|
if dump:
|
||||||
self.cache_dir = default_dump_dir()
|
self.cache_dir = (
|
||||||
|
default_dump_dir()
|
||||||
|
if default_dump_dir is not None
|
||||||
|
else os.path.join(Path.home(), ".triton", "dump")
|
||||||
|
)
|
||||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||||
os.makedirs(self.cache_dir, exist_ok=True)
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
elif override:
|
elif override:
|
||||||
self.cache_dir = default_override_dir()
|
self.cache_dir = (
|
||||||
|
default_override_dir()
|
||||||
|
if default_override_dir is not None
|
||||||
|
else os.path.join(Path.home(), ".triton", "override")
|
||||||
|
)
|
||||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
else:
|
else:
|
||||||
# create cache directory if it doesn't exist
|
# create cache directory if it doesn't exist
|
||||||
self.cache_dir = (
|
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
|
||||||
os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
|
default_cache_dir()
|
||||||
|
if default_cache_dir is not None
|
||||||
|
else os.path.join(Path.home(), ".triton", "cache")
|
||||||
)
|
)
|
||||||
if self.cache_dir:
|
if self.cache_dir:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user