[Auto Sync] Update registry.py (20250915) (#10484)
Co-authored-by: cctry <shiyang@x.ai>
This commit is contained in:
@@ -17,6 +17,18 @@ class _ModelRegistry:
|
|||||||
# Keyed by model_arch
|
# Keyed by model_arch
|
||||||
models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
|
models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def register(self, package_name: str, overwrite: bool = False):
|
||||||
|
new_models = import_model_classes(package_name)
|
||||||
|
if overwrite:
|
||||||
|
self.models.update(new_models)
|
||||||
|
else:
|
||||||
|
for arch, cls in new_models.items():
|
||||||
|
if arch in self.models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model architecture {arch} already registered. Set overwrite=True to replace."
|
||||||
|
)
|
||||||
|
self.models[arch] = cls
|
||||||
|
|
||||||
def get_supported_archs(self) -> AbstractSet[str]:
|
def get_supported_archs(self) -> AbstractSet[str]:
|
||||||
return self.models.keys()
|
return self.models.keys()
|
||||||
|
|
||||||
@@ -74,9 +86,8 @@ class _ModelRegistry:
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def import_model_classes():
|
def import_model_classes(package_name: str):
|
||||||
model_arch_name_to_cls = {}
|
model_arch_name_to_cls = {}
|
||||||
package_name = "sglang.srt.models"
|
|
||||||
package = importlib.import_module(package_name)
|
package = importlib.import_module(package_name)
|
||||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||||
if not ispkg:
|
if not ispkg:
|
||||||
@@ -104,4 +115,5 @@ def import_model_classes():
|
|||||||
return model_arch_name_to_cls
|
return model_arch_name_to_cls
|
||||||
|
|
||||||
|
|
||||||
ModelRegistry = _ModelRegistry(import_model_classes())
|
ModelRegistry = _ModelRegistry()
|
||||||
|
ModelRegistry.register("sglang.srt.models")
|
||||||
|
|||||||
Reference in New Issue
Block a user