# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import os
from collections.abc import Generator

import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors


class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors files from a local
    filesystem or an S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        self._is_distributed = False
        if load_config.model_loader_extra_config:
            extra_config = load_config.model_loader_extra_config

            if "distributed" in extra_config and isinstance(
                extra_config.get("distributed"), bool
            ):
                self._is_distributed = extra_config.get("distributed")

            # Streamer tuning knobs are forwarded via environment variables
            # read by the Run:ai streamer library.
            if "concurrency" in extra_config and isinstance(
                extra_config.get("concurrency"), int
            ):
                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
                    extra_config.get("concurrency")
                )

            if "memory_limit" in extra_config and isinstance(
                extra_config.get("memory_limit"), int
            ):
                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
                    extra_config.get("memory_limit")
                )

        # Fall back to AWS_ENDPOINT_URL when no streamer-specific S3 endpoint
        # has been configured.
        runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
        aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
        if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(
        self, model_name_or_path: str, revision: str | None
    ) -> list[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        is_object_storage_path = is_runai_obj_uri(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        hf_folder = (
            model_name_or_path
            if (is_local or is_object_storage_path)
            else download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [safetensors_pattern],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        )

        hf_weights_files = list_safetensors(path=hf_folder)

        if not is_local and not is_object_storage_path:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir, revision
            )

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`"
            )

        return hf_weights_files

    def _get_weights_iterator(
        self, model_or_path: str, revision: str | None
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_weights_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
            self._is_distributed,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download the model if necessary."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into a model."""
        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        model.load_weights(
            self._get_weights_iterator(model_weights, model_config.revision)
        )
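

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module): it shows how the loader
# is typically wired together. The LoadConfig fields used below
# ("load_format", "model_loader_extra_config") and their accepted values are
# assumptions that may vary between vLLM versions; `model_config` and `model`
# stand for a ModelConfig and an nn.Module built elsewhere.
#
#   from vllm.config.load import LoadConfig
#
#   load_config = LoadConfig(
#       load_format="runai_streamer",
#       model_loader_extra_config={"concurrency": 16, "memory_limit": 5 * 2**30},
#   )
#   loader = RunaiModelStreamerLoader(load_config)
#   loader.download_model(model_config)        # skips download for local/S3 paths
#   loader.load_weights(model, model_config)   # streams tensors into the model
# ---------------------------------------------------------------------------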