Fix cache modules of triton import error (#7832)
This commit is contained in:
@@ -83,12 +83,7 @@ from torch.func import functional_call
|
||||
from torch.library import Library
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from torch.utils._contextlib import _DecoratorContextManager
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
default_cache_dir,
|
||||
default_dump_dir,
|
||||
default_override_dir,
|
||||
)
|
||||
from triton.runtime.cache import FileCacheManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -923,18 +918,41 @@ class CustomCacheManager(FileCacheManager):
|
||||
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
|
||||
try:
|
||||
module_path = "triton.runtime.cache"
|
||||
cache_module = importlib.import_module(module_path)
|
||||
|
||||
default_cache_dir = getattr(cache_module, "default_cache_dir", None)
|
||||
default_dump_dir = getattr(cache_module, "default_dump_dir", None)
|
||||
default_override_dir = getattr(cache_module, "default_override_dir", None)
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
default_cache_dir = None
|
||||
default_dump_dir = None
|
||||
default_override_dir = None
|
||||
|
||||
if dump:
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = (
|
||||
default_dump_dir()
|
||||
if default_dump_dir is not None
|
||||
else os.path.join(Path.home(), ".triton", "dump")
|
||||
)
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif override:
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = (
|
||||
default_override_dir()
|
||||
if default_override_dir is not None
|
||||
else os.path.join(Path.home(), ".triton", "override")
|
||||
)
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = (
|
||||
os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
|
||||
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
|
||||
default_cache_dir()
|
||||
if default_cache_dir is not None
|
||||
else os.path.join(Path.home(), ".triton", "cache")
|
||||
)
|
||||
if self.cache_dir:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user