testing dynamic registration

This commit is contained in:
Chranos
2026-02-05 16:26:24 +08:00
parent 31e7cd3bf9
commit 2cb9f6ce1d
4 changed files with 336 additions and 45 deletions


@@ -160,9 +160,11 @@ _SPECULATIVE_DECODING_MODELS = {
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
-# Transformers backend models - for custom models with auto_map
+# Transformers backend models - wrapper classes for custom HuggingFace models
+# These provide the vLLM interface for models loaded via auto_map
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersForCausalLM": ("transformers_backend", "TransformersForCausalLM"),
+    # Text generation models
+    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }
 # yapf: enable
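The registry entry above maps the wrapper architecture name to a (module, class) pair. As a minimal sketch, assuming the upstream vLLM convention that the first element names a module under vllm.model_executor.models and the second the class inside it (the helper name below is made up for illustration):

from importlib import import_module

def load_registered_model(module_name: str, class_name: str):
    # Resolve a registry tuple like ("transformers", "TransformersForCausalLM")
    # by importing the module only when the class is actually needed.
    module = import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

# e.g. load_registered_model("transformers", "TransformersForCausalLM")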
@@ -171,6 +173,7 @@ _VLLM_MODELS = {
     **_EMBEDDING_MODELS,
     **_MULTIMODAL_MODELS,
     **_SPECULATIVE_DECODING_MODELS,
+    **_TRANSFORMERS_BACKEND_MODELS,
 }
 # Models not supported by ROCm.
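For context, the per-category tables are combined into one lookup table by dict unpacking, with later tables winning on key collisions. A small self-contained sketch with illustrative entries, not the real tables:

_TEXT_MODELS = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}
_WRAPPER_MODELS = {"TransformersForCausalLM": ("transformers", "TransformersForCausalLM")}

# The wrapper table is merged last, so its entries override earlier ones.
_ALL_MODELS = {**_TEXT_MODELS, **_WRAPPER_MODELS}
assert "TransformersForCausalLM" in _ALL_MODELS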
@@ -383,54 +386,77 @@ class _ModelRegistry:
         revision: Optional[str],
         trust_remote_code: bool,
         hf_config: Optional[object] = None,
-    ) -> Optional[Type[nn.Module]]:
+    ) -> Optional[str]:
         """
         Try to resolve a model architecture using the Transformers backend.
         This allows loading custom models that define their own implementation
         via the `auto_map` field in config.json.
-        Returns the loaded model class if successful, None otherwise.
+        Returns the vLLM wrapper architecture name (e.g. "TransformersForCausalLM")
+        if the model can be loaded via auto_map, None otherwise.
         """
-        # Check if architecture is in transformers
+        # If architecture is already a transformers backend model, return it
+        if architecture in _TRANSFORMERS_BACKEND_MODELS:
+            return architecture
+        # Check if architecture exists in transformers library
         model_module = getattr(transformers, architecture, None)
+        if model_module is not None:
+            # Model exists in transformers, can use TransformersForCausalLM wrapper
+            logger.info(
+                "Architecture %s found in transformers library, "
+                "using TransformersForCausalLM wrapper",
+                architecture
+            )
+            return "TransformersForCausalLM"
         # Get auto_map from hf_config
         auto_map: Dict[str, str] = {}
         if hf_config is not None:
             auto_map = getattr(hf_config, "auto_map", None) or {}
-        if model_module is None and auto_map:
-            # Try to load from auto_map
-            # First, ensure config class is loaded
-            for prefix in ("AutoConfig", "AutoModel"):
-                for name, module in auto_map.items():
-                    if name.startswith(prefix):
-                        try_get_class_from_dynamic_module(
-                            module,
-                            model_path,
-                            trust_remote_code=trust_remote_code,
-                            revision=revision,
-                            warn_on_fail=False,
-                        )
-            # Now try to load the model class
-            for name, module in auto_map.items():
-                if name.startswith("AutoModel"):
-                    model_module = try_get_class_from_dynamic_module(
-                        module,
-                        model_path,
-                        trust_remote_code=trust_remote_code,
-                        revision=revision,
-                        warn_on_fail=True,
-                    )
-                    if model_module is not None:
-                        logger.info(
-                            "Loaded custom model class %s from auto_map",
-                            model_module.__name__
-                        )
-                        return model_module
+        if not auto_map:
+            return None
-        return model_module
+        # Try to load from auto_map to verify it works
+        # First, ensure config class is loaded
+        for name, module in auto_map.items():
+            if name.startswith("AutoConfig"):
+                try_get_class_from_dynamic_module(
+                    module,
+                    model_path,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    warn_on_fail=False,
+                )
+        # Check if auto_map has a model class we can use
+        # Priority: AutoModelForCausalLM > AutoModelForSeq2SeqLM > AutoModel
+        auto_model_keys = sorted(
+            [k for k in auto_map.keys() if k.startswith("AutoModel")],
+            key=lambda x: (0 if "ForCausalLM" in x else (1 if "ForSeq2Seq" in x else 2))
+        )
+        for name in auto_model_keys:
+            module = auto_map[name]
+            model_cls = try_get_class_from_dynamic_module(
+                module,
+                model_path,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                warn_on_fail=True,
+            )
+            if model_cls is not None:
+                logger.info(
+                    "Found custom model class %s from auto_map[%s], "
+                    "using TransformersForCausalLM wrapper",
+                    model_cls.__name__,
+                    name
+                )
+                # Return the wrapper architecture, not the actual class
+                return "TransformersForCausalLM"
+        return None
     def _normalize_archs(
         self,
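The auto_map handling in this hunk follows the standard HuggingFace convention where config.json maps Auto* class names to "file.ClassName" references inside the model repository. A hedged sketch of a representative auto_map (file and class names invented for illustration) together with the same priority rule used above:

auto_map = {
    "AutoConfig": "configuration_mymodel.MyModelConfig",
    "AutoModel": "modeling_mymodel.MyModel",
    "AutoModelForCausalLM": "modeling_mymodel.MyModelForCausalLM",
}

# Same ordering as the diff: AutoModelForCausalLM > AutoModelForSeq2SeqLM > AutoModel.
auto_model_keys = sorted(
    [k for k in auto_map if k.startswith("AutoModel")],
    key=lambda x: (0 if "ForCausalLM" in x else (1 if "ForSeq2Seq" in x else 2)),
)
assert auto_model_keys[0] == "AutoModelForCausalLM"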
@@ -461,12 +487,14 @@ class _ModelRegistry:
         # Fallback: try to resolve using transformers backend (auto_map)
         if model_path and trust_remote_code and hf_config:
             for arch in architectures:
-                model_cls = self._try_resolve_transformers(
+                wrapper_arch = self._try_resolve_transformers(
                     arch, model_path, revision, trust_remote_code, hf_config
                 )
-                if model_cls is not None:
-                    # Create ModelInfo from the dynamically loaded class
-                    return _ModelInfo.from_model_cls(model_cls)
+                if wrapper_arch is not None:
+                    # Use the wrapper architecture's ModelInfo
+                    model_info = self._try_inspect_model_cls(wrapper_arch)
+                    if model_info is not None:
+                        return model_info
         return self._raise_for_unsupported(architectures)
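In the hunk above, the fallback no longer inspects the dynamically loaded class directly; it asks the registry for the wrapper architecture's model info instead. A simplified stand-in for that resolution order (the function and parameter names below are placeholders, not the real _ModelRegistry API):

from typing import Optional

def resolve_arch(arch: str, known_archs: set, wrapper_arch: Optional[str]) -> Optional[str]:
    # 1) Prefer an exact match against the static registry.
    if arch in known_archs:
        return arch
    # 2) Otherwise fall back to the wrapper suggested by the auto_map probe.
    return wrapper_arch

assert resolve_arch("MyCustomLM", {"LlamaForCausalLM"}, "TransformersForCausalLM") == "TransformersForCausalLM"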
@@ -488,11 +516,14 @@ class _ModelRegistry:
         # Fallback: try to resolve using transformers backend (auto_map)
         if model_path and trust_remote_code and hf_config:
             for arch in architectures:
-                model_cls = self._try_resolve_transformers(
+                wrapper_arch = self._try_resolve_transformers(
                     arch, model_path, revision, trust_remote_code, hf_config
                 )
-                if model_cls is not None:
-                    return (model_cls, arch)
+                if wrapper_arch is not None:
+                    model_cls = self._try_load_model_cls(wrapper_arch)
+                    if model_cls is not None:
+                        # Return wrapper class but keep original architecture name
+                        return (model_cls, arch)
         return self._raise_for_unsupported(architectures)
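This second call site keeps the same contract: the wrapper class is what gets loaded, but the architecture string returned alongside it stays the original name from the checkpoint's config. A toy sketch of that pairing, using a placeholder class rather than the real wrapper:

class TransformersForCausalLM:  # stand-in for the vLLM wrapper class
    pass

def resolve_model_cls_fallback(arch: str):
    # Mirrors the fallback branch: return the wrapper class, keep the original arch.
    return (TransformersForCausalLM, arch)

cls, arch = resolve_model_cls_fallback("MyCustomLM")
assert cls is TransformersForCausalLM and arch == "MyCustomLM"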