"""Model-loading entry point: ``get_model`` plus re-exported loader helpers (see ``__all__``)."""
from typing import Optional
|
|
|
|
from torch import nn
|
|
|
|
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
|
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
|
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
|
get_model_loader)
|
|
from vllm.model_executor.model_loader.utils import (
|
|
get_architecture_class_name, get_model_architecture)


def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
              device_config: DeviceConfig, parallel_config: ParallelConfig,
              scheduler_config: SchedulerConfig,
              lora_config: Optional[LoRAConfig],
              vision_language_config: Optional[VisionLanguageConfig]
              ) -> nn.Module:
    """Build and return a model using the loader selected by ``load_config``.

    ``load_config`` is used only to choose the loader via
    ``get_model_loader``. Every other config is forwarded unchanged, by
    keyword, to that loader's ``load_model``.

    All arguments are keyword-only. ``lora_config`` and
    ``vision_language_config`` have no defaults, so callers must pass them
    explicitly; either one may be ``None``.

    Returns:
        The object produced by the loader's ``load_model`` (annotated here
        as ``nn.Module``).
    """
    # The loader instance is not needed after the call, so resolve it and
    # delegate in one expression.
    return get_model_loader(load_config).load_model(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        lora_config=lora_config,
        vision_language_config=vision_language_config,
    )


# Public API of this module. ``get_model`` is defined above; the other names
# are re-exported from the ``loader`` and ``utils`` submodules imported at the
# top of the file.
__all__ = [
    "get_model",
    "get_model_loader",
    "BaseModelLoader",
    "get_architecture_class_name",
    "get_model_architecture",
]
|