"""
|
|
Whenever you add an architecture to this page, please also update
|
|
`tests/models/registry.py` with example HuggingFace models for it.
|
|
"""
|
|
import importlib
|
|
import os
|
|
import pickle
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from functools import lru_cache
|
|
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
|
|
TypeVar, Union)
|
|
|
|
import cloudpickle
|
|
import torch.nn as nn
|
|
import transformers
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
|
|
|
from .interfaces import (has_inner_state, is_attention_free,
|
|
supports_multimodal, supports_pp)
|
|
from .interfaces_base import is_embedding_model, is_text_generation_model
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
# Cache for architectures that have already been logged
|
|
_logged_transformers_architectures: set = set()
|
|
|
|
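# Each entry below maps an architecture name, as it appears in the
# `architectures` field of a HuggingFace `config.json`, to a tuple of
# (module name under `vllm.model_executor.models`, model class name).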
# yapf: disable
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel supports multimodal
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "HunYuanForCausalLM": ("hunyuan", "HunYuanForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    "Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel supports multimodal
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

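# Architectures usable as embedding models. Some decoder-only `*ForCausalLM`
# architectures are repeated here because, when served for embedding tasks,
# their hidden states can be pooled instead of being fed to the LM head.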
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
    "LlamaModel": ("llama", "LlamaEmbeddingModel"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaEmbeddingModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}

_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Transformers backend models - wrapper classes for custom HuggingFace models
# These provide the vLLM interface for models loaded via auto_map
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

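# The complete registry table: a plain union of the per-category tables above.
# Standard dict-merge semantics apply, so an architecture listed in more than
# one table resolves to the entry of the last table that defines it.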
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
}


@dataclass(frozen=True)
class _ModelInfo:
    is_text_generation_model: bool
    is_embedding_model: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        return _ModelInfo(
            is_text_generation_model=is_text_generation_model(model),
            is_embedding_model=is_embedding_model(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
        )


class _BaseRegisteredModel(ABC):

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]):
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> Type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


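# Module-level caches: each (architecture, registered model) pair is loaded or
# inspected at most once per process, so repeated registry queries do not
# re-import the model or spawn extra inspection subprocesses.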
@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None


@dataclass
class _ModelRegistry:
    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_path: str,
        revision: Optional[str],
        trust_remote_code: bool,
        hf_config: Optional[object] = None,
    ) -> Optional[str]:
        """
        Try to resolve a model architecture using the Transformers backend.
        This allows loading custom models that define their own implementation
        via the `auto_map` field in config.json.

        Returns the vLLM wrapper architecture name
        (e.g. "TransformersForCausalLM") if the model can be loaded via
        auto_map, None otherwise.
        """
        # If architecture is already a transformers backend model, return it
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        # Check if architecture exists in transformers library
        model_module = getattr(transformers, architecture, None)
        if model_module is not None:
            # Model exists in transformers, can use TransformersForCausalLM
            # wrapper. Only log once per architecture to avoid spam.
            if architecture not in _logged_transformers_architectures:
                _logged_transformers_architectures.add(architecture)
                logger.info(
                    "Architecture %s found in transformers library, "
                    "using TransformersForCausalLM wrapper", architecture)
            return "TransformersForCausalLM"

        # Get auto_map from hf_config
        auto_map: Dict[str, str] = {}
        if hf_config is not None:
            auto_map = getattr(hf_config, "auto_map", None) or {}

        if not auto_map:
            return None

        # Try to load from auto_map to verify it works
        # First, ensure config class is loaded
        for name, module in auto_map.items():
            if name.startswith("AutoConfig"):
                try_get_class_from_dynamic_module(
                    module,
                    model_path,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    warn_on_fail=False,
                )

        # Check if auto_map has a model class we can use
        # Priority: AutoModelForCausalLM > AutoModelForSeq2SeqLM > AutoModel
        auto_model_keys = sorted(
            [k for k in auto_map.keys() if k.startswith("AutoModel")],
            key=lambda x: (0 if "ForCausalLM" in x else
                           (1 if "ForSeq2Seq" in x else 2)))

        for name in auto_model_keys:
            module = auto_map[name]
            model_cls = try_get_class_from_dynamic_module(
                module,
                model_path,
                trust_remote_code=trust_remote_code,
                revision=revision,
                warn_on_fail=True,
            )
            if model_cls is not None:
                # Only log once per model class to avoid spam
                log_key = f"{model_cls.__name__}_{name}"
                if not hasattr(self, '_logged_custom_models'):
                    self._logged_custom_models = set()
                if log_key not in self._logged_custom_models:
                    logger.info(
                        "Found custom model class %s from auto_map[%s], "
                        "using TransformersForCausalLM wrapper",
                        model_cls.__name__, name)
                    self._logged_custom_models.add(log_key)
                # Return the wrapper architecture, not the actual class
                return "TransformersForCausalLM"

        return None

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> _ModelInfo:
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return model_info

        # Fallback: try to resolve using the Transformers backend (auto_map)
        if model_path and trust_remote_code and hf_config:
            for arch in architectures:
                wrapper_arch = self._try_resolve_transformers(
                    arch, model_path, revision, trust_remote_code, hf_config)
                if wrapper_arch is not None:
                    # Use the wrapper architecture's ModelInfo
                    model_info = self._try_inspect_model_cls(wrapper_arch)
                    if model_info is not None:
                        return model_info

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> Tuple[Type[nn.Module], str]:
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback: try to resolve using the Transformers backend (auto_map)
        if model_path and trust_remote_code and hf_config:
            for arch in architectures:
                wrapper_arch = self._try_resolve_transformers(
                    arch, model_path, revision, trust_remote_code, hf_config)
                if wrapper_arch is not None:
                    model_cls = self._try_load_model_cls(wrapper_arch)
                    if model_cls is not None:
                        # Return the wrapper class but keep the original
                        # architecture name
                        return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).is_text_generation_model

    def is_embedding_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).is_embedding_model

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
        model_path: Optional[str] = None,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
        hf_config: Optional[object] = None,
    ) -> bool:
        return self.inspect_model_cls(
            architectures, model_path, revision, trust_remote_code, hf_config
        ).is_attention_free


ModelRegistry = _ModelRegistry({
    model_arch: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

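# Illustrative usage (not executed here): given a loaded HuggingFace config
# `config`, callers typically resolve a model class via
#     model_cls, arch = ModelRegistry.resolve_model_cls(config.architectures)
# and out-of-tree plugins register additional models via
#     ModelRegistry.register_model("MyModel", "my_module:MyModel")
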
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{returned.stderr.decode()}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


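# Entry point for the helper subprocess spawned by `_run_in_subprocess`: the
# parent pipes a cloudpickled `(fn, output_filepath)` tuple via stdin, and this
# process writes the pickled result of `fn()` to that file.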
def _run() -> None:
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))


if __name__ == "__main__":
    _run()