testing dynamic register
This commit is contained in:
@@ -160,9 +160,11 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
}
|
||||
|
||||
# Transformers backend models - for custom models with auto_map
|
||||
# Transformers backend models - wrapper classes for custom HuggingFace models
|
||||
# These provide the vLLM interface for models loaded via auto_map
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersForCausalLM": ("transformers_backend", "TransformersForCausalLM"),
|
||||
# Text generation models
|
||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
@@ -171,6 +173,7 @@ _VLLM_MODELS = {
|
||||
**_EMBEDDING_MODELS,
|
||||
**_MULTIMODAL_MODELS,
|
||||
**_SPECULATIVE_DECODING_MODELS,
|
||||
**_TRANSFORMERS_BACKEND_MODELS,
|
||||
}
|
||||
|
||||
# Models not supported by ROCm.
|
||||
@@ -383,54 +386,77 @@ class _ModelRegistry:
|
||||
revision: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
hf_config: Optional[object] = None,
|
||||
) -> Optional[Type[nn.Module]]:
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Try to resolve a model architecture using the Transformers backend.
|
||||
This allows loading custom models that define their own implementation
|
||||
via the `auto_map` field in config.json.
|
||||
|
||||
Returns the loaded model class if successful, None otherwise.
|
||||
Returns the vLLM wrapper architecture name (e.g. "TransformersForCausalLM")
|
||||
if the model can be loaded via auto_map, None otherwise.
|
||||
"""
|
||||
# Check if architecture is in transformers
|
||||
# If architecture is already a transformers backend model, return it
|
||||
if architecture in _TRANSFORMERS_BACKEND_MODELS:
|
||||
return architecture
|
||||
|
||||
# Check if architecture exists in transformers library
|
||||
model_module = getattr(transformers, architecture, None)
|
||||
if model_module is not None:
|
||||
# Model exists in transformers, can use TransformersForCausalLM wrapper
|
||||
logger.info(
|
||||
"Architecture %s found in transformers library, "
|
||||
"using TransformersForCausalLM wrapper",
|
||||
architecture
|
||||
)
|
||||
return "TransformersForCausalLM"
|
||||
|
||||
# Get auto_map from hf_config
|
||||
auto_map: Dict[str, str] = {}
|
||||
if hf_config is not None:
|
||||
auto_map = getattr(hf_config, "auto_map", None) or {}
|
||||
|
||||
if model_module is None and auto_map:
|
||||
# Try to load from auto_map
|
||||
# First, ensure config class is loaded
|
||||
for prefix in ("AutoConfig", "AutoModel"):
|
||||
for name, module in auto_map.items():
|
||||
if name.startswith(prefix):
|
||||
try_get_class_from_dynamic_module(
|
||||
module,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
warn_on_fail=False,
|
||||
)
|
||||
|
||||
# Now try to load the model class
|
||||
for name, module in auto_map.items():
|
||||
if name.startswith("AutoModel"):
|
||||
model_module = try_get_class_from_dynamic_module(
|
||||
module,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
warn_on_fail=True,
|
||||
)
|
||||
if model_module is not None:
|
||||
logger.info(
|
||||
"Loaded custom model class %s from auto_map",
|
||||
model_module.__name__
|
||||
)
|
||||
return model_module
|
||||
if not auto_map:
|
||||
return None
|
||||
|
||||
return model_module
|
||||
# Try to load from auto_map to verify it works
|
||||
# First, ensure config class is loaded
|
||||
for name, module in auto_map.items():
|
||||
if name.startswith("AutoConfig"):
|
||||
try_get_class_from_dynamic_module(
|
||||
module,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
warn_on_fail=False,
|
||||
)
|
||||
|
||||
# Check if auto_map has a model class we can use
|
||||
# Priority: AutoModelForCausalLM > AutoModelForSeq2SeqLM > any other AutoModel* key
|
||||
auto_model_keys = sorted(
|
||||
[k for k in auto_map.keys() if k.startswith("AutoModel")],
|
||||
key=lambda x: (0 if "ForCausalLM" in x else (1 if "ForSeq2Seq" in x else 2))
|
||||
)
|
||||
|
||||
for name in auto_model_keys:
|
||||
module = auto_map[name]
|
||||
model_cls = try_get_class_from_dynamic_module(
|
||||
module,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
warn_on_fail=True,
|
||||
)
|
||||
if model_cls is not None:
|
||||
logger.info(
|
||||
"Found custom model class %s from auto_map[%s], "
|
||||
"using TransformersForCausalLM wrapper",
|
||||
model_cls.__name__,
|
||||
name
|
||||
)
|
||||
# Return the wrapper architecture, not the actual class
|
||||
return "TransformersForCausalLM"
|
||||
|
||||
return None
|
||||
|
||||
def _normalize_archs(
|
||||
self,
|
||||
@@ -461,12 +487,14 @@ class _ModelRegistry:
|
||||
# Fallback: try to resolve using transformers backend (auto_map)
|
||||
if model_path and trust_remote_code and hf_config:
|
||||
for arch in architectures:
|
||||
model_cls = self._try_resolve_transformers(
|
||||
wrapper_arch = self._try_resolve_transformers(
|
||||
arch, model_path, revision, trust_remote_code, hf_config
|
||||
)
|
||||
if model_cls is not None:
|
||||
# Create ModelInfo from the dynamically loaded class
|
||||
return _ModelInfo.from_model_cls(model_cls)
|
||||
if wrapper_arch is not None:
|
||||
# Use the wrapper architecture's ModelInfo
|
||||
model_info = self._try_inspect_model_cls(wrapper_arch)
|
||||
if model_info is not None:
|
||||
return model_info
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
@@ -488,11 +516,14 @@ class _ModelRegistry:
|
||||
# Fallback: try to resolve using transformers backend (auto_map)
|
||||
if model_path and trust_remote_code and hf_config:
|
||||
for arch in architectures:
|
||||
model_cls = self._try_resolve_transformers(
|
||||
wrapper_arch = self._try_resolve_transformers(
|
||||
arch, model_path, revision, trust_remote_code, hf_config
|
||||
)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
if wrapper_arch is not None:
|
||||
model_cls = self._try_load_model_cls(wrapper_arch)
|
||||
if model_cls is not None:
|
||||
# Return wrapper class but keep original architecture name
|
||||
return (model_cls, arch)
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user