Sync from v0.13
This commit is contained in:
@@ -1,30 +1,150 @@
|
||||
from typing import Optional
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
||||
get_model_loader)
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
||||
from vllm.model_executor.model_loader.runai_streamer_loader import (
|
||||
RunaiModelStreamerLoader,
|
||||
)
|
||||
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture)
|
||||
get_architecture_class_name,
|
||||
get_model_architecture,
|
||||
get_model_cls,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Reminder: Please update docstring in `LoadConfig`
|
||||
# if a new load format is added here
|
||||
LoadFormats = Literal[
|
||||
"auto",
|
||||
"hf",
|
||||
"bitsandbytes",
|
||||
"dummy",
|
||||
"fastsafetensors",
|
||||
"gguf",
|
||||
"mistral",
|
||||
"npcache",
|
||||
"pt",
|
||||
"runai_streamer",
|
||||
"runai_streamer_sharded",
|
||||
"safetensors",
|
||||
"sharded_state",
|
||||
"tensorizer",
|
||||
]
|
||||
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
|
||||
"auto": DefaultModelLoader,
|
||||
"hf": DefaultModelLoader,
|
||||
"bitsandbytes": BitsAndBytesModelLoader,
|
||||
"dummy": DummyModelLoader,
|
||||
"fastsafetensors": DefaultModelLoader,
|
||||
"gguf": GGUFModelLoader,
|
||||
"mistral": DefaultModelLoader,
|
||||
"npcache": DefaultModelLoader,
|
||||
"pt": DefaultModelLoader,
|
||||
"runai_streamer": RunaiModelStreamerLoader,
|
||||
"runai_streamer_sharded": ShardedStateLoader,
|
||||
"safetensors": DefaultModelLoader,
|
||||
"sharded_state": ShardedStateLoader,
|
||||
"tensorizer": TensorizerLoader,
|
||||
}
|
||||
|
||||
|
||||
def register_model_loader(load_format: str):
|
||||
"""Register a customized vllm model loader.
|
||||
|
||||
When a load format is not supported by vllm, you can register a customized
|
||||
model loader to support it.
|
||||
|
||||
Args:
|
||||
load_format (str): The model loader format name.
|
||||
|
||||
Examples:
|
||||
>>> from vllm.config.load import LoadConfig
|
||||
>>> from vllm.model_executor.model_loader import (
|
||||
... get_model_loader,
|
||||
... register_model_loader,
|
||||
... )
|
||||
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
>>>
|
||||
>>> @register_model_loader("my_loader")
|
||||
... class MyModelLoader(BaseModelLoader):
|
||||
... def download_model(self):
|
||||
... pass
|
||||
...
|
||||
... def load_weights(self):
|
||||
... pass
|
||||
>>>
|
||||
>>> load_config = LoadConfig(load_format="my_loader")
|
||||
>>> type(get_model_loader(load_config))
|
||||
<class 'MyModelLoader'>
|
||||
""" # noqa: E501
|
||||
|
||||
def _wrapper(model_loader_cls):
|
||||
if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
|
||||
logger.warning(
|
||||
"Load format `%s` is already registered, and will be "
|
||||
"overwritten by the new loader class `%s`.",
|
||||
load_format,
|
||||
model_loader_cls,
|
||||
)
|
||||
if not issubclass(model_loader_cls, BaseModelLoader):
|
||||
raise ValueError(
|
||||
"The model loader must be a subclass of `BaseModelLoader`."
|
||||
)
|
||||
_LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
|
||||
logger.info(
|
||||
"Registered model loader `%s` with load format `%s`",
|
||||
model_loader_cls,
|
||||
load_format,
|
||||
)
|
||||
return model_loader_cls
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
load_format = load_config.load_format
|
||||
if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
|
||||
raise ValueError(f"Load format `{load_format}` is not supported")
|
||||
return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
|
||||
|
||||
|
||||
def get_model(
|
||||
*, model_config: ModelConfig, load_config: LoadConfig,
|
||||
device_config: DeviceConfig, parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
|
||||
loader = get_model_loader(load_config)
|
||||
return loader.load_model(model_config=model_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config)
|
||||
*, vllm_config: VllmConfig, model_config: ModelConfig | None = None
|
||||
) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
return loader.load_model(vllm_config=vllm_config, model_config=model_config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_model", "get_model_loader", "BaseModelLoader",
|
||||
"get_architecture_class_name", "get_model_architecture"
|
||||
"get_model",
|
||||
"get_model_loader",
|
||||
"get_architecture_class_name",
|
||||
"get_model_architecture",
|
||||
"get_model_cls",
|
||||
"register_model_loader",
|
||||
"BaseModelLoader",
|
||||
"BitsAndBytesModelLoader",
|
||||
"GGUFModelLoader",
|
||||
"DefaultModelLoader",
|
||||
"DummyModelLoader",
|
||||
"RunaiModelStreamerLoader",
|
||||
"ShardedStateLoader",
|
||||
"TensorizerLoader",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user