80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import importlib
|
|
import os
|
|
import sys
|
|
from collections import namedtuple
|
|
from dataclasses import fields
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
Item = namedtuple("Item", ["constructor", "config"])
|
|
|
|
|
|
# credit: snippet used in ClassyVision (and probably other places)
|
|
def import_all_modules(root: str, base_module: str) -> List[str]:
|
|
modules: List[str] = []
|
|
for file in os.listdir(root):
|
|
if file.endswith((".py", ".pyc")) and not file.startswith("_"):
|
|
module = file[: file.find(".py")]
|
|
if module not in sys.modules:
|
|
module_name = ".".join([base_module, module])
|
|
importlib.import_module(module_name)
|
|
modules.append(module_name)
|
|
|
|
return modules
|
|
|
|
|
|
def get_registry_decorator(
|
|
class_registry, name_registry, reference_class, default_config
|
|
) -> Callable[[str, Any], Callable[[Any], Any]]:
|
|
def register_item(name: str, config: Any = default_config):
|
|
"""Registers a subclass.
|
|
|
|
This decorator allows xFormers to instantiate a given subclass
|
|
from a configuration file, even if the class itself is not part of the
|
|
xFormers library."""
|
|
|
|
def register_cls(cls):
|
|
if name in class_registry:
|
|
raise ValueError("Cannot register duplicate item ({})".format(name))
|
|
if not issubclass(cls, reference_class):
|
|
raise ValueError(
|
|
"Item ({}: {}) must extend the base class: {}".format(
|
|
name, cls.__name__, reference_class.__name__
|
|
)
|
|
)
|
|
if cls.__name__ in name_registry:
|
|
raise ValueError(
|
|
"Cannot register item with duplicate class name ({})".format(
|
|
cls.__name__
|
|
)
|
|
)
|
|
|
|
class_registry[name] = Item(constructor=cls, config=config)
|
|
name_registry.add(cls.__name__)
|
|
return cls
|
|
|
|
return register_cls
|
|
|
|
return register_item
|
|
|
|
|
|
def generate_matching_config(superset: Dict[str, Any], config_class: Any) -> Any:
|
|
"""Given a superset of the inputs and a reference config class,
|
|
return exactly the needed config"""
|
|
|
|
# Extract the required fields
|
|
field_names = list(map(lambda x: x.name, fields(config_class)))
|
|
subset = {k: v for k, v in superset.items() if k in field_names}
|
|
|
|
# The missing fields get Noned
|
|
for k in field_names:
|
|
if k not in subset.keys():
|
|
subset[k] = None
|
|
|
|
return config_class(**subset)
|