# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import copy
from collections.abc import Generator

import torch
from torch import nn

from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig,
    deserialize_tensorizer_model,
    init_tensorizer_model,
    is_vllm_tensorized,
    serialize_vllm_model,
    tensorizer_weights_iterator,
)
from vllm.model_executor.model_loader.utils import (
    get_model_architecture,
    initialize_model,
)
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)

# Tensorizer kwargs users may not override: vLLM controls these itself.
BLACKLISTED_TENSORIZER_ARGS = {
    "device",  # vLLM decides this
    "dtype",  # vLLM decides this
    "mode",  # Not meant to be configurable by the user
}


def validate_config(config: dict) -> None:
    """Reject user-supplied Tensorizer kwargs that vLLM must control.

    Args:
        config: Raw ``model_loader_extra_config`` mapping from the user.

    Raises:
        ValueError: If any blacklisted key is set to a non-None value.
    """
    for k, v in config.items():
        if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
            raise ValueError(f"{k} is not an allowed Tensorizer argument.")


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig) -> None:
        super().__init__(load_config)
        # The extra config is either an already-built TensorizerConfig, or a
        # dict holding raw kwargs under the "tensorizer_config" key; in the
        # latter case validate it before constructing the config object.
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            validate_config(load_config.model_loader_extra_config)
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config["tensorizer_config"]
            )

    def _verify_config(
        self, model_config: ModelConfig, parallel_config: ParallelConfig
    ) -> None:
        # Cross-check tensorizer settings against the model and parallel
        # configuration; the verify_* helpers raise on mismatch.
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        # Stream (name, tensor) pairs from the serialized model via the
        # tensorizer library, using args derived from our config.
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/others/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            # Only model construction happens on the target device; the
            # weight streaming below runs outside the device context.
            with torch.device(device_config.device):
                model = initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Touch the serialized-model stream so it gets fetched/cached."""
        self.tensorizer_config.verify_with_model_config(model_config)

        # Opening (and immediately closing) the stream is enough to trigger
        # the download; no bytes need to be read here.
        with self.tensorizer_config.open_stream():
            pass

    def _patch_tensorizer_config(
        self, model_config: ModelConfig
    ) -> TensorizerConfig:
        """Return a copy of our config filled in with per-model details."""
        model_class = get_model_architecture(model_config)[0]
        # Shallow copy so per-load fields don't mutate the shared
        # self.tensorizer_config instance.
        tensorizer_config = copy.copy(self.tensorizer_config)
        tensorizer_config.model_class = model_class
        tensorizer_config.hf_config = model_config.hf_config
        tensorizer_config.dtype = model_config.dtype
        return tensorizer_config

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load serialized model weights with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/others/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        if is_vllm_tensorized(self.tensorizer_config):
            # vLLM-tensorized path: deserialize directly into the model.
            tensorizer_config = self._patch_tensorizer_config(model_config)
            deserialize_tensorizer_model(model, tensorizer_config)
        else:
            # Fallback: stream raw tensors through the generic weight loader.
            model.load_weights(self._get_weights_iterator())

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Build and return the model, loading weights via tensorizer."""
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # NOTE(review): assumes tensorizer_uri is a printf-style template
            # (e.g. contains "%d") that resolves to this rank's shard; also
            # note this mutates the shared config in place, so a second
            # load_model call on the same loader would re-apply "%" to an
            # already-formatted URI — confirm loaders are single-use.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri
                % get_tensor_model_parallel_rank()
            )

        if is_vllm_tensorized(self.tensorizer_config):
            # Fast path: model was serialized by vLLM, so it can be
            # instantiated and deserialized directly on the target device.
            tensorizer_config = self._patch_tensorizer_config(model_config)
            device_config = vllm_config.device_config
            with set_default_torch_dtype(model_config.dtype):
                with torch.device(device_config.device):
                    model = init_tensorizer_model(
                        tensorizer_config=tensorizer_config,
                        vllm_config=vllm_config
                    )
            self.load_weights(model, model_config)
            return model
        # Slow path: non-vLLM-tensorized checkpoint, loaded via CPU.
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig | dict,
        model_config: ModelConfig,
    ) -> None:
        """Serialize a vLLM model with tensorizer.

        Accepts either a ready TensorizerConfig or a raw kwargs dict.
        """
        if isinstance(tensorizer_config, dict):
            tensorizer_config = TensorizerConfig(**tensorizer_config)
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
            model_config=model_config,
        )