Files
2025-08-05 19:02:46 +08:00

80 lines
2.7 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
import sys
from collections import namedtuple
from dataclasses import fields
from typing import Any, Callable, Dict, List
Item = namedtuple("Item", ["constructor", "config"])
# credit: snippet used in ClassyVision (and probably other places)
def import_all_modules(root: str, base_module: str) -> List[str]:
modules: List[str] = []
for file in os.listdir(root):
if file.endswith((".py", ".pyc")) and not file.startswith("_"):
module = file[: file.find(".py")]
if module not in sys.modules:
module_name = ".".join([base_module, module])
importlib.import_module(module_name)
modules.append(module_name)
return modules
def get_registry_decorator(
class_registry, name_registry, reference_class, default_config
) -> Callable[[str, Any], Callable[[Any], Any]]:
def register_item(name: str, config: Any = default_config):
"""Registers a subclass.
This decorator allows xFormers to instantiate a given subclass
from a configuration file, even if the class itself is not part of the
xFormers library."""
def register_cls(cls):
if name in class_registry:
raise ValueError("Cannot register duplicate item ({})".format(name))
if not issubclass(cls, reference_class):
raise ValueError(
"Item ({}: {}) must extend the base class: {}".format(
name, cls.__name__, reference_class.__name__
)
)
if cls.__name__ in name_registry:
raise ValueError(
"Cannot register item with duplicate class name ({})".format(
cls.__name__
)
)
class_registry[name] = Item(constructor=cls, config=config)
name_registry.add(cls.__name__)
return cls
return register_cls
return register_item
def generate_matching_config(superset: Dict[str, Any], config_class: Any) -> Any:
"""Given a superset of the inputs and a reference config class,
return exactly the needed config"""
# Extract the required fields
field_names = list(map(lambda x: x.name, fields(config_class)))
subset = {k: v for k, v in superset.items() if k in field_names}
# The missing fields get Noned
for k in field_names:
if k not in subset.keys():
subset[k] = None
return config_class(**subset)