testing dynamic register
This commit is contained in:
@@ -353,8 +353,20 @@ class ModelConfig:
|
|||||||
task_support: Dict[_Task, bool] = {
|
task_support: Dict[_Task, bool] = {
|
||||||
# NOTE: Listed from highest to lowest priority,
|
# NOTE: Listed from highest to lowest priority,
|
||||||
# in case the model supports multiple of them
|
# in case the model supports multiple of them
|
||||||
"generate": ModelRegistry.is_text_generation_model(architectures),
|
"generate": ModelRegistry.is_text_generation_model(
|
||||||
"embedding": ModelRegistry.is_embedding_model(architectures),
|
architectures,
|
||||||
|
model_path=self.model,
|
||||||
|
revision=self.revision,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
hf_config=hf_config,
|
||||||
|
),
|
||||||
|
"embedding": ModelRegistry.is_embedding_model(
|
||||||
|
architectures,
|
||||||
|
model_path=self.model,
|
||||||
|
revision=self.revision,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
hf_config=hf_config,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
supported_tasks_lst: List[_Task] = [
|
supported_tasks_lst: List[_Task] = [
|
||||||
task for task, is_supported in task_support.items() if is_supported
|
task for task, is_supported in task_support.items() if is_supported
|
||||||
|
|||||||
@@ -160,9 +160,11 @@ _SPECULATIVE_DECODING_MODELS = {
|
|||||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Transformers backend models - for custom models with auto_map
|
# Transformers backend models - wrapper classes for custom HuggingFace models
|
||||||
|
# These provide the vLLM interface for models loaded via auto_map
|
||||||
_TRANSFORMERS_BACKEND_MODELS = {
|
_TRANSFORMERS_BACKEND_MODELS = {
|
||||||
"TransformersForCausalLM": ("transformers_backend", "TransformersForCausalLM"),
|
# Text generation models
|
||||||
|
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||||
}
|
}
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
@@ -171,6 +173,7 @@ _VLLM_MODELS = {
|
|||||||
**_EMBEDDING_MODELS,
|
**_EMBEDDING_MODELS,
|
||||||
**_MULTIMODAL_MODELS,
|
**_MULTIMODAL_MODELS,
|
||||||
**_SPECULATIVE_DECODING_MODELS,
|
**_SPECULATIVE_DECODING_MODELS,
|
||||||
|
**_TRANSFORMERS_BACKEND_MODELS,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Models not supported by ROCm.
|
# Models not supported by ROCm.
|
||||||
@@ -383,54 +386,77 @@ class _ModelRegistry:
|
|||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
hf_config: Optional[object] = None,
|
hf_config: Optional[object] = None,
|
||||||
) -> Optional[Type[nn.Module]]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Try to resolve a model architecture using the Transformers backend.
|
Try to resolve a model architecture using the Transformers backend.
|
||||||
This allows loading custom models that define their own implementation
|
This allows loading custom models that define their own implementation
|
||||||
via the `auto_map` field in config.json.
|
via the `auto_map` field in config.json.
|
||||||
|
|
||||||
Returns the loaded model class if successful, None otherwise.
|
Returns the vLLM wrapper architecture name (e.g. "TransformersForCausalLM")
|
||||||
|
if the model can be loaded via auto_map, None otherwise.
|
||||||
"""
|
"""
|
||||||
# Check if architecture is in transformers
|
# If architecture is already a transformers backend model, return it
|
||||||
|
if architecture in _TRANSFORMERS_BACKEND_MODELS:
|
||||||
|
return architecture
|
||||||
|
|
||||||
|
# Check if architecture exists in transformers library
|
||||||
model_module = getattr(transformers, architecture, None)
|
model_module = getattr(transformers, architecture, None)
|
||||||
|
if model_module is not None:
|
||||||
|
# Model exists in transformers, can use TransformersForCausalLM wrapper
|
||||||
|
logger.info(
|
||||||
|
"Architecture %s found in transformers library, "
|
||||||
|
"using TransformersForCausalLM wrapper",
|
||||||
|
architecture
|
||||||
|
)
|
||||||
|
return "TransformersForCausalLM"
|
||||||
|
|
||||||
# Get auto_map from hf_config
|
# Get auto_map from hf_config
|
||||||
auto_map: Dict[str, str] = {}
|
auto_map: Dict[str, str] = {}
|
||||||
if hf_config is not None:
|
if hf_config is not None:
|
||||||
auto_map = getattr(hf_config, "auto_map", None) or {}
|
auto_map = getattr(hf_config, "auto_map", None) or {}
|
||||||
|
|
||||||
if model_module is None and auto_map:
|
if not auto_map:
|
||||||
# Try to load from auto_map
|
return None
|
||||||
# First, ensure config class is loaded
|
|
||||||
for prefix in ("AutoConfig", "AutoModel"):
|
|
||||||
for name, module in auto_map.items():
|
|
||||||
if name.startswith(prefix):
|
|
||||||
try_get_class_from_dynamic_module(
|
|
||||||
module,
|
|
||||||
model_path,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
revision=revision,
|
|
||||||
warn_on_fail=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now try to load the model class
|
|
||||||
for name, module in auto_map.items():
|
|
||||||
if name.startswith("AutoModel"):
|
|
||||||
model_module = try_get_class_from_dynamic_module(
|
|
||||||
module,
|
|
||||||
model_path,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
revision=revision,
|
|
||||||
warn_on_fail=True,
|
|
||||||
)
|
|
||||||
if model_module is not None:
|
|
||||||
logger.info(
|
|
||||||
"Loaded custom model class %s from auto_map",
|
|
||||||
model_module.__name__
|
|
||||||
)
|
|
||||||
return model_module
|
|
||||||
|
|
||||||
return model_module
|
# Try to load from auto_map to verify it works
|
||||||
|
# First, ensure config class is loaded
|
||||||
|
for name, module in auto_map.items():
|
||||||
|
if name.startswith("AutoConfig"):
|
||||||
|
try_get_class_from_dynamic_module(
|
||||||
|
module,
|
||||||
|
model_path,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
revision=revision,
|
||||||
|
warn_on_fail=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if auto_map has a model class we can use
|
||||||
|
# Priority: AutoModelForCausalLM > AutoModelForSeq2SeqLM > AutoModel
|
||||||
|
auto_model_keys = sorted(
|
||||||
|
[k for k in auto_map.keys() if k.startswith("AutoModel")],
|
||||||
|
key=lambda x: (0 if "ForCausalLM" in x else (1 if "ForSeq2Seq" in x else 2))
|
||||||
|
)
|
||||||
|
|
||||||
|
for name in auto_model_keys:
|
||||||
|
module = auto_map[name]
|
||||||
|
model_cls = try_get_class_from_dynamic_module(
|
||||||
|
module,
|
||||||
|
model_path,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
revision=revision,
|
||||||
|
warn_on_fail=True,
|
||||||
|
)
|
||||||
|
if model_cls is not None:
|
||||||
|
logger.info(
|
||||||
|
"Found custom model class %s from auto_map[%s], "
|
||||||
|
"using TransformersForCausalLM wrapper",
|
||||||
|
model_cls.__name__,
|
||||||
|
name
|
||||||
|
)
|
||||||
|
# Return the wrapper architecture, not the actual class
|
||||||
|
return "TransformersForCausalLM"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def _normalize_archs(
|
def _normalize_archs(
|
||||||
self,
|
self,
|
||||||
@@ -461,12 +487,14 @@ class _ModelRegistry:
|
|||||||
# Fallback: try to resolve using transformers backend (auto_map)
|
# Fallback: try to resolve using transformers backend (auto_map)
|
||||||
if model_path and trust_remote_code and hf_config:
|
if model_path and trust_remote_code and hf_config:
|
||||||
for arch in architectures:
|
for arch in architectures:
|
||||||
model_cls = self._try_resolve_transformers(
|
wrapper_arch = self._try_resolve_transformers(
|
||||||
arch, model_path, revision, trust_remote_code, hf_config
|
arch, model_path, revision, trust_remote_code, hf_config
|
||||||
)
|
)
|
||||||
if model_cls is not None:
|
if wrapper_arch is not None:
|
||||||
# Create ModelInfo from the dynamically loaded class
|
# Use the wrapper architecture's ModelInfo
|
||||||
return _ModelInfo.from_model_cls(model_cls)
|
model_info = self._try_inspect_model_cls(wrapper_arch)
|
||||||
|
if model_info is not None:
|
||||||
|
return model_info
|
||||||
|
|
||||||
return self._raise_for_unsupported(architectures)
|
return self._raise_for_unsupported(architectures)
|
||||||
|
|
||||||
@@ -488,11 +516,14 @@ class _ModelRegistry:
|
|||||||
# Fallback: try to resolve using transformers backend (auto_map)
|
# Fallback: try to resolve using transformers backend (auto_map)
|
||||||
if model_path and trust_remote_code and hf_config:
|
if model_path and trust_remote_code and hf_config:
|
||||||
for arch in architectures:
|
for arch in architectures:
|
||||||
model_cls = self._try_resolve_transformers(
|
wrapper_arch = self._try_resolve_transformers(
|
||||||
arch, model_path, revision, trust_remote_code, hf_config
|
arch, model_path, revision, trust_remote_code, hf_config
|
||||||
)
|
)
|
||||||
if model_cls is not None:
|
if wrapper_arch is not None:
|
||||||
return (model_cls, arch)
|
model_cls = self._try_load_model_cls(wrapper_arch)
|
||||||
|
if model_cls is not None:
|
||||||
|
# Return wrapper class but keep original architecture name
|
||||||
|
return (model_cls, arch)
|
||||||
|
|
||||||
return self._raise_for_unsupported(architectures)
|
return self._raise_for_unsupported(architectures)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
"""Wrapper around `transformers` models for vLLM v0.6.2.
|
||||||
|
|
||||||
|
This module provides a simplified Transformers modeling backend that wraps
|
||||||
|
any HuggingFace model with the vLLM interface, enabling support for custom
|
||||||
|
models that define their implementation via `auto_map` in config.json.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from vllm.model_executor.models.transformers.causal import TransformersForCausalLM
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TransformersForCausalLM",
|
||||||
|
]
|
||||||
234
vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py
Normal file
234
vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
"""Transformers modeling backend for causal language models.
|
||||||
|
|
||||||
|
This module provides a wrapper class that enables vLLM to use any HuggingFace
|
||||||
|
causal language model, including custom models that define their implementation
|
||||||
|
via `auto_map` in config.json.
|
||||||
|
|
||||||
|
The key insight is that we use HuggingFace's AutoModelForCausalLM to load the
|
||||||
|
actual model, then wrap it with the vLLM interface (compute_logits, sample, etc).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformersForCausalLM(nn.Module):
|
||||||
|
"""
|
||||||
|
A wrapper class that adapts any HuggingFace causal language model
|
||||||
|
to the vLLM interface.
|
||||||
|
|
||||||
|
This class provides:
|
||||||
|
1. forward() - processes input through the model
|
||||||
|
2. compute_logits() - computes output logits
|
||||||
|
3. sample() - samples tokens from logits
|
||||||
|
4. load_weights() - loads model weights
|
||||||
|
|
||||||
|
The actual HuggingFace model is loaded using AutoModelForCausalLM and
|
||||||
|
stored in self.model.
|
||||||
|
|
||||||
|
Interface compliance:
|
||||||
|
- Implements VllmModel protocol (vllm_config init, forward with required args)
|
||||||
|
- Implements VllmModelForTextGeneration protocol (compute_logits, sample)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
self.config = config.hf_config
|
||||||
|
self.model_config = config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
|
logger.info("Using Transformers modeling backend for %s",
|
||||||
|
config.hf_config.architectures)
|
||||||
|
|
||||||
|
# Load the actual HuggingFace model
|
||||||
|
self._load_hf_model()
|
||||||
|
|
||||||
|
# Setup logits processor and sampler
|
||||||
|
self.logits_processor = LogitsProcessor(
|
||||||
|
self.config.vocab_size,
|
||||||
|
logits_as_input=False,
|
||||||
|
)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def _load_hf_model(self) -> None:
|
||||||
|
"""Load the HuggingFace model using AutoModelForCausalLM."""
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
# We load with minimal config first - weights will be loaded separately
|
||||||
|
# by vLLM's weight loader
|
||||||
|
logger.info("Loading HuggingFace model from config...")
|
||||||
|
|
||||||
|
self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
|
||||||
|
self.config,
|
||||||
|
torch_dtype=self.model_config.dtype,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable gradient computation for inference
|
||||||
|
self.model.eval()
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: "AttentionMetadata",
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass through the model.
|
||||||
|
|
||||||
|
This method conforms to the VllmModel protocol by accepting:
|
||||||
|
- input_ids: Token IDs
|
||||||
|
- positions: Position IDs
|
||||||
|
- kv_caches: KV cache tensors (not used in basic HF forward)
|
||||||
|
- attn_metadata: Attention metadata (not used in basic HF forward)
|
||||||
|
|
||||||
|
Note: This is a simplified implementation that does not use vLLM's
|
||||||
|
optimized attention mechanisms. For production use with KV caching,
|
||||||
|
a more sophisticated implementation would be needed.
|
||||||
|
"""
|
||||||
|
# For simplicity, we use HuggingFace's native forward
|
||||||
|
# This won't have vLLM's optimizations but will work
|
||||||
|
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
|
else:
|
||||||
|
model_inputs = {"input_ids": input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids}
|
||||||
|
|
||||||
|
# Position IDs
|
||||||
|
if positions is not None:
|
||||||
|
model_inputs["position_ids"] = positions.unsqueeze(0) if positions.dim() == 1 else positions
|
||||||
|
|
||||||
|
# Run the model
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(
|
||||||
|
**model_inputs,
|
||||||
|
use_cache=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get hidden states from the last layer
|
||||||
|
# For CausalLM, we typically want the hidden states before the LM head
|
||||||
|
if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
|
||||||
|
hidden_states = outputs.hidden_states[-1]
|
||||||
|
else:
|
||||||
|
# Fall back to running without output_hidden_states
|
||||||
|
# and getting logits directly
|
||||||
|
hidden_states = outputs.logits
|
||||||
|
if hidden_states.dim() == 3:
|
||||||
|
hidden_states = hidden_states.squeeze(0)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
if hidden_states.dim() == 3:
|
||||||
|
hidden_states = hidden_states.squeeze(0)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Compute logits from hidden states.
|
||||||
|
|
||||||
|
This method conforms to the VllmModelForTextGeneration protocol.
|
||||||
|
"""
|
||||||
|
# If hidden_states are already logits (from forward), process them
|
||||||
|
if hidden_states.shape[-1] == self.config.vocab_size:
|
||||||
|
logits = hidden_states
|
||||||
|
else:
|
||||||
|
# Apply the LM head
|
||||||
|
logits = self.model.lm_head(hidden_states)
|
||||||
|
|
||||||
|
return self.logits_processor(None, logits, sampling_metadata)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
"""
|
||||||
|
Sample tokens from logits.
|
||||||
|
|
||||||
|
This method conforms to the VllmModelForTextGeneration protocol.
|
||||||
|
"""
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(
|
||||||
|
self,
|
||||||
|
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||||
|
) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Load weights into the model.
|
||||||
|
|
||||||
|
This method loads weights from an iterable of (name, tensor) pairs
|
||||||
|
into the HuggingFace model.
|
||||||
|
"""
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
model_params = dict(self.model.named_parameters())
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
# Try to find the parameter in the model
|
||||||
|
if name in model_params:
|
||||||
|
param = model_params[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
else:
|
||||||
|
# Try common prefixes
|
||||||
|
for prefix in ["model.", ""]:
|
||||||
|
full_name = f"{prefix}{name}" if prefix else name
|
||||||
|
if full_name in model_params:
|
||||||
|
param = model_params[full_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
break
|
||||||
|
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
def is_backend_compatible() -> bool:
|
||||||
|
"""
|
||||||
|
Check if the current model is compatible with the Transformers backend.
|
||||||
|
|
||||||
|
This is a simplified check - in practice, compatibility depends on
|
||||||
|
whether the model follows standard HuggingFace conventions.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
Reference in New Issue
Block a user