diff --git a/python/sglang/srt/models/registry.py b/python/sglang/srt/models/registry.py index 76e042a95..5e2a3c67e 100644 --- a/python/sglang/srt/models/registry.py +++ b/python/sglang/srt/models/registry.py @@ -17,6 +17,18 @@ class _ModelRegistry: # Keyed by model_arch models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict) + def register(self, package_name: str, overwrite: bool = False): + new_models = import_model_classes(package_name) + if overwrite: + self.models.update(new_models) + else: + for arch, cls in new_models.items(): + if arch in self.models: + raise ValueError( + f"Model architecture {arch} already registered. Set overwrite=True to replace." + ) + self.models[arch] = cls + def get_supported_archs(self) -> AbstractSet[str]: return self.models.keys() @@ -74,9 +86,8 @@ class _ModelRegistry: @lru_cache() -def import_model_classes(): +def import_model_classes(package_name: str): model_arch_name_to_cls = {} - package_name = "sglang.srt.models" package = importlib.import_module(package_name) for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): if not ispkg: @@ -104,4 +115,5 @@ def import_model_classes(): return model_arch_name_to_cls -ModelRegistry = _ModelRegistry(import_model_classes()) +ModelRegistry = _ModelRegistry() +ModelRegistry.register("sglang.srt.models")