Sync from v0.13
This commit is contained in:
151
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
151
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig,
|
||||
deserialize_tensorizer_model,
|
||||
init_tensorizer_model,
|
||||
is_vllm_tensorized,
|
||||
serialize_vllm_model,
|
||||
tensorizer_weights_iterator,
|
||||
)
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_model_architecture,
|
||||
initialize_model,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BLACKLISTED_TENSORIZER_ARGS = {
|
||||
"device", # vLLM decides this
|
||||
"dtype", # vLLM decides this
|
||||
"mode", # Not meant to be configurable by the user
|
||||
}
|
||||
|
||||
|
||||
def validate_config(config: dict):
|
||||
for k, v in config.items():
|
||||
if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
|
||||
raise ValueError(f"{k} is not an allowed Tensorizer argument.")
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
|
||||
"""Model loader using CoreWeave's tensorizer library."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
|
||||
self.tensorizer_config = load_config.model_loader_extra_config
|
||||
else:
|
||||
validate_config(load_config.model_loader_extra_config)
|
||||
self.tensorizer_config = TensorizerConfig(
|
||||
**load_config.model_loader_extra_config["tensorizer_config"]
|
||||
)
|
||||
|
||||
def _verify_config(
|
||||
self, model_config: ModelConfig, parallel_config: ParallelConfig
|
||||
):
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
self.tensorizer_config.verify_with_parallel_config(parallel_config)
|
||||
|
||||
def _get_weights_iterator(
|
||||
self,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
||||
return tensorizer_weights_iterator(tensorizer_args)
|
||||
|
||||
def _load_model_serialized_cpu(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer to the CPU.
|
||||
|
||||
This is only necessary when the model isn't vLLM-tensorized (see
|
||||
examples/others/tensorize_vllm_model.py) This should still
|
||||
be faster than default HuggingFace loading, but will be slower than
|
||||
loading a vLLM-tensorized model.
|
||||
"""
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
return model.eval()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
|
||||
with self.tensorizer_config.open_stream():
|
||||
pass
|
||||
|
||||
def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
tensorizer_config = copy.copy(self.tensorizer_config)
|
||||
tensorizer_config.model_class = model_class
|
||||
tensorizer_config.hf_config = model_config.hf_config
|
||||
tensorizer_config.dtype = model_config.dtype
|
||||
return tensorizer_config
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
"""Load serialized model weights with tensorizer.
|
||||
|
||||
Expects a vLLM-tensorized model. See the
|
||||
examples/others/tensorize_vllm_model.py example script
|
||||
for serializing vLLM models."""
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
tensorizer_config = self._patch_tensorizer_config(model_config)
|
||||
deserialize_tensorizer_model(model, tensorizer_config)
|
||||
else:
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
) -> nn.Module:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self._verify_config(model_config, parallel_config)
|
||||
|
||||
if parallel_config.tensor_parallel_size > 1:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
self.tensorizer_config.tensorizer_uri = (
|
||||
self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
|
||||
)
|
||||
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
tensorizer_config = self._patch_tensorizer_config(model_config)
|
||||
device_config = vllm_config.device_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = init_tensorizer_model(
|
||||
tensorizer_config=tensorizer_config, vllm_config=vllm_config
|
||||
)
|
||||
self.load_weights(model, model_config)
|
||||
return model
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config)
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
tensorizer_config: TensorizerConfig | dict,
|
||||
model_config: ModelConfig,
|
||||
) -> None:
|
||||
if isinstance(tensorizer_config, dict):
|
||||
tensorizer_config = TensorizerConfig(**tensorizer_config)
|
||||
serialize_vllm_model(
|
||||
model=model,
|
||||
tensorizer_config=tensorizer_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
Reference in New Issue
Block a user