init
This commit is contained in:
119
vllm/model_executor/models/__init__.py
Executable file
119
vllm/model_executor/models/__init__.py
Executable file
@@ -0,0 +1,119 @@
|
||||
import importlib
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
# Module-level logger; ModelRegistry uses it for its ROCm partial-support
# warning and for the warning emitted when a registration shadows an
# in-tree architecture.
logger = init_logger(__name__)
|
||||
|
||||
# Architecture -> (module, class).
# The module name is relative to ``vllm.model_executor.models``; it is only
# imported when ModelRegistry.load_model_cls is asked for that architecture.
# Several architectures intentionally share one implementation (e.g. the
# Llama class also serves Aquila, InternLM, Mistral and Phi3 entries).
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "LlavaForConditionalGeneration":
    ("llava", "LlavaForConditionalGeneration"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}
|
||||
|
||||
# Architecture -> type.
# Out-of-tree models: filled at runtime by ModelRegistry.register_model and
# consulted by ModelRegistry.load_model_cls before the in-tree _MODELS table.
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
|
||||
|
||||
# Models not supported by ROCm.
# Architecture names listed here make ModelRegistry.load_model_cls raise
# ValueError when is_hip() is true. Currently empty.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||
|
||||
# Models partially supported by ROCm.
# Architecture -> Reason.
# The reason string is logged as a warning by ModelRegistry.load_model_cls
# when is_hip() is true; the model class is still returned.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}
|
||||
|
||||
|
||||
class ModelRegistry:
    """Resolve model architecture names to model classes.

    Two module-level tables back the registry: ``_OOT_MODELS`` (classes
    registered at runtime through ``register_model``) and ``_MODELS``
    (in-tree ``(module, class)`` name pairs that are imported on demand).
    """

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Return the model class registered for ``model_arch``.

        Args:
            model_arch: Architecture name to resolve.

        Returns:
            The registered class, or ``None`` when the architecture is in
            neither table or the imported module lacks the expected class.

        Raises:
            ValueError: If ``is_hip()`` is true and the in-tree architecture
                is listed in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        # Out-of-tree registrations win over in-tree entries and are returned
        # without going through the ROCm checks below.
        if model_arch in _OOT_MODELS:
            return _OOT_MODELS[model_arch]
        if model_arch not in _MODELS:
            return None
        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    "Model architecture %s is partially supported by ROCm: %s",
                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])

        # Import lazily so only the requested model's module gets loaded.
        module_name, model_cls_name = _MODELS[model_arch]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return the architecture names known to the registry.

        In-tree names come first, in ``_MODELS`` order, followed by any
        out-of-tree names added through ``register_model``. A name present in
        both tables is listed once.
        """
        archs = list(_MODELS)
        # Out-of-tree models are loadable too, so they must be reported here;
        # skip names that merely override an in-tree entry to avoid duplicates.
        archs.extend(arch for arch in _OOT_MODELS if arch not in _MODELS)
        return archs

    @staticmethod
    def register_model(model_arch: str, model_cls: Type[nn.Module]) -> None:
        """Register an out-of-tree model class under ``model_arch``.

        Args:
            model_arch: Architecture name the class should be served for.
            model_cls: Model class to return from ``load_model_cls``.

        A warning is logged when ``model_arch`` is also an in-tree
        architecture, because the new class takes precedence over it.
        """
        if model_arch in _MODELS:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls.__name__)
        # In-place mutation of the module-level dict; no ``global`` needed
        # because the name is never rebound.
        _OOT_MODELS[model_arch] = model_cls
||||
|
||||
|
||||
# Public API of this package: only the registry is exported; the lookup
# tables above are private implementation details.
__all__ = [
    "ModelRegistry",
]
|
||||
Reference in New Issue
Block a user