add dynamic register

Chranos
2026-02-05 15:53:43 +08:00
parent 9563c9af0d
commit 92f0016e6f
4 changed files with 244 additions and 14 deletions


@@ -16,9 +16,11 @@ from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
import cloudpickle
import torch.nn as nn
import transformers
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from .interfaces import (has_inner_state, is_attention_free,
                         supports_multimodal, supports_pp)
@@ -157,6 +159,11 @@ _SPECULATIVE_DECODING_MODELS = {
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# Transformers backend models - for custom models with auto_map
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersForCausalLM": ("transformers_backend", "TransformersForCausalLM"),
}
# yapf: enable
_VLLM_MODELS = {
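For reference, the `auto_map` that this fallback consumes is the standard Hugging Face field in a model repo's config.json, mapping Auto* class names to "<module>.<class>" paths inside the repo. A minimal sketch, with hypothetical module and class names:

# Hypothetical config.json excerpt, viewed as the dict that ends up on
# hf_config.auto_map; configuration_mymodel / modeling_mymodel are
# placeholder module files shipped inside the model repo.
auto_map = {
    "AutoConfig": "configuration_mymodel.MyModelConfig",
    "AutoModelForCausalLM": "modeling_mymodel.MyModelForCausalLM",
}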
@@ -369,6 +376,62 @@ class _ModelRegistry:
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_path: str,
        revision: Optional[str],
        trust_remote_code: bool,
        hf_config: Optional[object] = None,
    ) -> Optional[Type[nn.Module]]:
        """
        Try to resolve a model architecture using the Transformers backend.

        This allows loading custom models that define their own
        implementation via the `auto_map` field in config.json.

        Returns the loaded model class if successful, None otherwise.
        """
        # Check if architecture is in transformers
        model_module = getattr(transformers, architecture, None)

        # Get auto_map from hf_config
        auto_map: Dict[str, str] = {}
        if hf_config is not None:
            auto_map = getattr(hf_config, "auto_map", None) or {}

        if model_module is None and auto_map:
            # Try to load from auto_map
            # First, ensure config class is loaded
            for prefix in ("AutoConfig", "AutoModel"):
                for name, module in auto_map.items():
                    if name.startswith(prefix):
                        try_get_class_from_dynamic_module(
                            module,
                            model_path,
                            trust_remote_code=trust_remote_code,
                            revision=revision,
                            warn_on_fail=False,
                        )

            # Now try to load the model class
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_path,
                        trust_remote_code=trust_remote_code,
                        revision=revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        logger.info(
                            "Loaded custom model class %s from auto_map",
                            model_module.__name__
                        )
                        return model_module

        return model_module

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
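The two passes in `_try_resolve_transformers` are ordered deliberately: the custom config class must be importable before the model module executes, because modeling files typically reference their config class at import time. A minimal standalone sketch of the same resolution order, where `load` stands in for `try_get_class_from_dynamic_module` and all names are illustrative:

from typing import Callable, Dict, Optional

def resolve_auto_map(
    auto_map: Dict[str, str],
    load: Callable[[str], Optional[type]],
) -> Optional[type]:
    # Pass 1: import AutoConfig/AutoModel entries first so the custom
    # config class is registered before the model module executes.
    for prefix in ("AutoConfig", "AutoModel"):
        for name, ref in auto_map.items():
            if name.startswith(prefix):
                load(ref)
    # Pass 2: return the first AutoModel* entry that resolves to a class.
    for name, ref in auto_map.items():
        if name.startswith("AutoModel"):
            cls = load(ref)
            if cls is not None:
                return cls
    return None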
@@ -383,6 +446,10 @@ class _ModelRegistry:
    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> _ModelInfo:
        architectures = self._normalize_archs(architectures)
@@ -391,11 +458,25 @@ class _ModelRegistry:
            if model_info is not None:
                return model_info

        # Fallback: try to resolve using transformers backend (auto_map)
        if model_path and trust_remote_code and hf_config:
            for arch in architectures:
                model_cls = self._try_resolve_transformers(
                    arch, model_path, revision, trust_remote_code, hf_config
                )
                if model_cls is not None:
                    # Create ModelInfo from the dynamically loaded class
                    return _ModelInfo.from_model_cls(model_cls)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> Tuple[Type[nn.Module], str]:
        architectures = self._normalize_archs(architectures)
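Callers opt into the fallback by passing the new keyword arguments; with the defaults (model_path=None, trust_remote_code=False, hf_config=None) the registry behaves exactly as before. A hedged usage sketch, where the repo id, architecture name, and config object are placeholders:

# Hypothetical call site; "org/custom-model", "MyModelForCausalLM", and
# hf_config are placeholders, not names defined by this commit.
info = ModelRegistry.inspect_model_cls(
    ["MyModelForCausalLM"],
    model_path="org/custom-model",
    revision="main",
    trust_remote_code=True,
    hf_config=hf_config,
)
assert info.is_text_generation_model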
@@ -404,39 +485,88 @@ class _ModelRegistry:
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback: try to resolve using transformers backend (auto_map)
        if model_path and trust_remote_code and hf_config:
            for arch in architectures:
                model_cls = self._try_resolve_transformers(
                    arch, model_path, revision, trust_remote_code, hf_config
                )
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(architectures).is_text_generation_model
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).is_text_generation_model

    def is_embedding_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(architectures).is_embedding_model
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).is_embedding_model

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_multimodal
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_pp
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).supports_pp

    def model_has_inner_state(self, architectures: Union[str,
                                                         List[str]]) -> bool:
        return self.inspect_model_cls(architectures).has_inner_state

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).has_inner_state

    def is_attention_free_model(self, architectures: Union[str,
                                                           List[str]]) -> bool:
        return self.inspect_model_cls(architectures).is_attention_free

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).is_attention_free
ModelRegistry = _ModelRegistry({
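Putting the pieces together, a hedged end-to-end sketch of resolving a custom architecture, assuming a repo whose config.json carries an auto_map like the excerpt above (the repo id is a placeholder):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(
    "org/custom-model", trust_remote_code=True)  # placeholder repo id
model_cls, arch = ModelRegistry.resolve_model_cls(
    hf_config.architectures,
    model_path="org/custom-model",
    trust_remote_code=True,
    hf_config=hf_config,
)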