update
This commit is contained in:
37
vllm/model_executor/model_loader/reload/__init__.py
Normal file
37
vllm/model_executor/model_loader/reload/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Layerwise weight reloading utilities for vLLM.
|
||||
|
||||
This module provides functionality to reload model weights layer-by-layer,
|
||||
which is useful for weight updates without full model reconstruction.
|
||||
|
||||
Limitations:
|
||||
1. Composition with CPU offloading has not been implemented
|
||||
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
|
||||
3. Tied parameters will only reflect processing from one of the parent layers (for
|
||||
example, only processing from embed_tokens will have an effect)
|
||||
4. This design assumes that the number of weights loaded from disk is the same as the
|
||||
number of weights created at model init time. This is not true for quant methods
|
||||
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
|
||||
cases are non-issues for today's quant methods, but future quantizations may cause
|
||||
reloading to fail
|
||||
"""
|
||||
|
||||
# Imports first, then the public-API declaration (PEP 8 ordering);
# `__all__` only contains strings, so placing it after the imports is safe.
from .layerwise import (
    finalize_layerwise_reload,
    initialize_layerwise_reload,
    record_metadata_for_reloading,
)
from .torchao_decorator import (
    set_torchao_reload_attrs,
    support_quantized_model_reload_from_hp_weights,
)

__all__ = [
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
    "finalize_layerwise_reload",
    "set_torchao_reload_attrs",
    "support_quantized_model_reload_from_hp_weights",
]
|
||||
275
vllm/model_executor/model_loader/reload/layerwise.py
Normal file
275
vllm/model_executor/model_loader/reload/layerwise.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .meta import (
|
||||
capture_layer_to_meta,
|
||||
get_numel_loaded,
|
||||
materialize_layer,
|
||||
restore_layer_on_meta,
|
||||
)
|
||||
from .types import LayerReloadingInfo
|
||||
from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"get_layerwise_info",
|
||||
"record_metadata_for_reloading",
|
||||
"initialize_layerwise_reload",
|
||||
"finalize_layerwise_reload",
|
||||
]
|
||||
|
||||
|
||||
# Global dict storing information used for layerwise restoring, loading, and processing.
|
||||
# For more information regarding what info is stored when, see `LayerReloadingInfo`
|
||||
#
|
||||
# Use a weak ref dictionary so that modules can be freed when the model is freed.
|
||||
# Values are sanitized from references to the layer key in order to avoid circular refs
|
||||
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
|
||||
WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
    """
    Get information related to restoring and layerwise processing. If no previous
    information existed, a new entry is constructed and memoized in
    `LAYERWISE_INFO`.

    :param layer: module whose reloading info should be fetched
    :return: the (possibly newly created) `LayerReloadingInfo` for `layer`
    """
    # EAFP: a single dict lookup on the common (already-present) path instead
    # of a `not in` membership test followed by a second lookup.
    try:
        return LAYERWISE_INFO[layer]
    except KeyError:
        info = LAYERWISE_INFO[layer] = LayerReloadingInfo()
        return info
|
||||
|
||||
|
||||
def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Record layer metadata needed for later reloading.

    For every submodule, captures parameter and buffer metadata as meta
    tensors so the layer can later be restored to its model (pre-processing)
    format. Must be called before `initialize_layerwise_reload`.

    :param model: model whose layer metadata should be recorded
    """
    for module in model.modules():
        layer_info = get_layerwise_info(module)
        layer_info.restore_metadata = capture_layer_to_meta(module)
|
||||
|
||||
|
||||
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to original tensor storage

    :param model: model whose layers should be prepared for layerwise reload
    """
    # disable torchao reloading to avoid infinite recursion
    # (the previous flag value is restored by `finalize_layerwise_reload`)
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Skip if the layer has already been initialized
        # (`can_process` becomes true once `load_numel_total` is set below)
        if info.can_process():
            continue

        # Save current tensors for later copying
        info.kernel_tensors = get_layer_params_buffers(layer)

        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)

        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)

        # Wrap each parameter's weight loader
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)
|
||||
|
||||
|
||||
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
    """
    Create a wrapped weight loader that defers processing.

    The returned loader caches the bound call arguments instead of applying
    them to the real tensors, counts how many elements each call would load
    (via `get_numel_loaded` on the meta-device tensors), and triggers
    `_layerwise_process` once the layer's expected total is reached.

    :param layer: layer owning the tensor whose loader is wrapped
    :param param_name: attribute name of the tensor on `layer`
    :return: the wrapping loader; its name `online_process_loader` is how
        wrappers are detected (and unwrapped) elsewhere in this module
    """
    info = get_layerwise_info(layer)
    param = getattr(layer, param_name)
    original_loader = _get_original_loader(param)
    loader_signature = inspect.signature(original_loader)

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
        if not info.can_process():
            # Unfortunately, some qconfigs are set up to load the same weight
            # multiple times. For example, CT_WNA16 loads `weight_shape` for
            # each of the qkv partitions. This results in layers loading extra
            # weights (beyond load_numel_total) after it's already processed.
            #
            # Best solution is to ensure that `load_numel_total` reflects the
            # actual number of weights loaded, either by modifying qconfigs to
            # create as many weights as loaded (see padding issue as well)
            # or maybe capturing how many weights are loaded on first pass
            #
            # For now, `load_numel_total` is still safe to use as long as
            # there's no way to reach `load_numel_total` without loading all
            # necessary weights. `weight_shape` is very small, so this is safe.
            # see Limitations(4)
            logger.debug("%s: Excessive loading", layer.__class__.__name__)
            return

        # Bind and normalize arguments
        bound_args = loader_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Cache loaded weights, track loading progress
        info.loaded_weights.append((param_name, bound_args))
        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
        info.load_numel += num_loaded

        logger.debug(
            "%s: %d / %d",
            layer.__class__.__name__,
            info.load_numel,
            info.load_numel_total,
        )

        # Process and copy when all weights are loaded
        # (Attention/MLA layers are instead handled by `finalize_layerwise_reload`,
        # which runs after every other layer has been processed)
        if info.load_numel >= info.load_numel_total and not isinstance(  # type: ignore[operator]
            layer, (Attention, MLAAttention)
        ):
            _layerwise_process(layer, info)

        return ret

    return online_process_loader
|
||||
|
||||
|
||||
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be applied after `initialize_layerwise_reload` in order
    to unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all other
    layers.

    :param model: model previously prepared by `initialize_layerwise_reload`
    :param model_config: config whose dtype is passed to attention processing
    :raises NotImplementedError: if Attention/MLA scale weights were loaded
        (see Limitations(2))
    """
    # Restore the torchao reload flag saved by `initialize_layerwise_reload`
    model._do_torchao_reload = model._original_do_torchao_reload

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )

            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)

        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)

        # Process non-attention layers which did not load all elements. This can happen
        # if the created weight has extra padding elements which are not loaded
        # Having too many of these delayed layers can lead to excess memory usage
        # see Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)

        info.reset()
|
||||
|
||||
|
||||
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
    Finalize layer loading after all weights have been cached.

    This function:
    1. Materializes the layer onto the target device
    2. Loads all cached weights
    3. Runs quantization processing if applicable
    4. Copies processed values back to original tensor storage

    :param layer: layer whose cached weights should be applied
    :param info: reloading info holding the cached weights and kernel tensors
    """
    # Materialize layer tensors onto device
    materialize_layer(layer)

    # Reset FP8 online quantization flag so process_weights_after_loading
    # will run again during reload
    if hasattr(layer, "_already_called_process_weights_after_loading"):
        delattr(layer, "_already_called_process_weights_after_loading")

    # Unwrap layerwise loading wrappers
    for param in get_layer_tensors(layer).values():
        param.weight_loader = _get_original_loader(param)

    # Load all cached weights into materialized layer (using original loaders).
    # The cached `param` argument still points at the meta tensor, so it is
    # replaced with the freshly materialized tensor before each call.
    for name, args in info.loaded_weights:
        param = getattr(layer, name)
        args.arguments["param"] = param
        param.weight_loader(*args.args, **args.kwargs)

    # Process weights (quantization, repacking, etc.)
    # Attention/MLA are processed in `finalize_layerwise_reload`
    quant_method = getattr(layer, "quant_method", None)
    if isinstance(quant_method, QuantizeMethodBase):
        quant_method.process_weights_after_loading(layer)

    # Copy processed values into original tensor storage (preserves cudagraph refs)
    # this code is a no-op if not reloading (because kernel tensors is empty)
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        param.data.copy_(getattr(layer, name))
    for name, buffer in buffers.items():
        buffer.data.copy_(getattr(layer, name))

    _place_kernel_tensors(layer, info)

    info.reset()
    logger.debug("%s: Processed", layer.__class__.__name__)
|
||||
|
||||
|
||||
def _get_original_loader(tensor: torch.Tensor) -> Callable:
    """Return the weight loader with any layerwise wrappers removed."""
    unwrapped = _get_weight_loader(tensor)
    # Wrappers may be nested (shared tensors), so peel them off one by one
    while unwrapped.__name__ == "online_process_loader":
        unwrapped = unwrapped.__wrapped__  # type: ignore[union-attr]
    return unwrapped
|
||||
|
||||
|
||||
def _get_weight_loader(tensor: torch.Tensor):
    # Tensors without an attached `weight_loader` fall back to the default
    # (plain copy) loader.
    return getattr(tensor, "weight_loader", default_weight_loader)
|
||||
|
||||
|
||||
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Replace the layer's current tensors with the saved kernel-format tensors."""
    # Drop whatever tensors are currently attached to the layer
    for tensor_name in list(get_layer_tensors(layer)):
        delattr(layer, tensor_name)

    # Re-register the saved kernel tensors in their place
    kernel_params, kernel_buffers = info.kernel_tensors
    for param_name, kernel_param in kernel_params.items():
        layer.register_parameter(param_name, kernel_param)
    for buffer_name, kernel_buffer in kernel_buffers.items():
        layer.register_buffer(buffer_name, kernel_buffer)
|
||||
146
vllm/model_executor/model_loader/reload/meta.py
Normal file
146
vllm/model_executor/model_loader/reload/meta.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
from .sanitize import restore_layer_refs, sanitize_layer_refs
|
||||
from .types import LayerReloadingInfo, LayerTensors
|
||||
from .utils import get_layer_params_buffers, get_layer_tensors
|
||||
|
||||
__all__ = [
|
||||
"to_meta_tensor",
|
||||
"materialize_meta_tensor",
|
||||
"capture_layer_to_meta",
|
||||
"restore_layer_on_meta",
|
||||
"materialize_layer",
|
||||
"get_numel_loaded",
|
||||
]
|
||||
|
||||
SKIP_MODULES: set[str] = {"HadamardTransform"}
|
||||
|
||||
SKIP_TENSORS: set[str] = {
|
||||
"_expert_map",
|
||||
"expert_mask",
|
||||
"expert_global_to_physical",
|
||||
"expert_physical_to_global",
|
||||
"expert_local_to_global",
|
||||
}
|
||||
|
||||
|
||||
def to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert a tensor to a meta tensor while preserving class and attributes.

    Only the storage moves to the meta device; the tensor's subclass and its
    python-level attributes (e.g. attached weight loaders) are carried over.
    """
    shadow = tensor.data.to("meta")
    # Carry over subclass identity and a shallow copy of python attributes
    shadow.__class__ = tensor.__class__
    shadow.__dict__ = dict(tensor.__dict__)
    return shadow
|
||||
|
||||
|
||||
def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
    """
    Materialize a meta tensor into an actual tensor on the current device.

    Should be called within the torch device context for the given rank so the
    (uninitialized) storage is allocated on the correct device.
    """
    # Allocate uninitialized storage with the same geometry as the meta tensor
    real = torch.empty_strided(
        size=tuple(meta_tensor.size()),
        stride=tuple(meta_tensor.stride()),
        dtype=meta_tensor.dtype,
        requires_grad=False,
    )
    # Preserve subclass identity and python-level attributes
    real.__class__ = meta_tensor.__class__
    real.__dict__ = dict(meta_tensor.__dict__)
    return real
|
||||
|
||||
|
||||
def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
    """
    Capture meta-device copies of a layer's parameters and buffers.

    The captured tensors preserve class and python attributes (see
    `to_meta_tensor`) but are sanitized of bound-method references back to
    `layer` to avoid circular references (see `sanitize_layer_refs`).
    Modules named in `SKIP_MODULES` and tensors named in `SKIP_TENSORS`
    are excluded.

    :param layer: module whose tensors should be captured
    :return: (parameters, buffers) dicts of sanitized meta tensors
    """
    if layer.__class__.__name__ in SKIP_MODULES:
        return ({}, {})

    params, buffers = get_layer_params_buffers(layer)
    return (
        {
            name: sanitize_layer_refs(to_meta_tensor(param), layer)
            for name, param in params.items()
            if name not in SKIP_TENSORS
        },
        {
            name: sanitize_layer_refs(to_meta_tensor(buffer), layer)
            for name, buffer in buffers.items()
            if name not in SKIP_TENSORS
        },
    )
|
||||
|
||||
|
||||
def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Restore a layer to model format with tensors on the meta device."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return

    # Drop the current (kernel-format) tensors before re-registering
    for tensor_name in get_layer_tensors(layer):
        if tensor_name not in SKIP_TENSORS:
            delattr(layer, tensor_name)

    meta_params, meta_buffers = info.restore_metadata
    for param_name, meta_param in meta_params.items():
        if param_name in SKIP_TENSORS:
            continue
        # Re-attach bound-method references sanitized at capture time
        layer.register_parameter(param_name, restore_layer_refs(meta_param, layer))

    for buffer_name, meta_buffer in meta_buffers.items():
        if buffer_name in SKIP_TENSORS:
            continue
        layer.register_buffer(buffer_name, restore_layer_refs(meta_buffer, layer))
|
||||
|
||||
|
||||
def materialize_layer(layer: torch.nn.Module) -> None:
    """Materialize all meta tensors in a layer to actual tensors."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return

    for tensor_name, meta_tensor in get_layer_tensors(layer).items():
        if tensor_name in SKIP_TENSORS:
            continue
        setattr(layer, tensor_name, materialize_meta_tensor(meta_tensor))
|
||||
|
||||
|
||||
class MetaCopyCounter(TorchDispatchMode):
    """
    Tracks the total number of elements modified with `copy_`.

    Useful for keeping track of weight loading where underlying weights can be
    arbitrarily transformed (such as with `narrow`) before calling copy.

    Note: Assumes that copy kwargs are not used.
    """

    def __init__(self):
        super().__init__()
        # Running total of elements copied into meta-device tensors
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Only count in-place copies whose destination lives on the meta device
        if func is torch.ops.aten.copy_.default:
            dst, src = args[0], args[1]
            if dst.device.type == "meta":
                assert dst.numel() == src.numel()
                self.copied_numel += dst.numel()

        return func(*args, **kwargs)
|
||||
|
||||
|
||||
def get_numel_loaded(
    weight_loader: Callable, args: inspect.BoundArguments
) -> tuple[int, object]:
    """
    Determine how many elements would be loaded by a weight loader call.

    The bound `param` argument must be on the meta device so the loader's
    `copy_` calls are cheap no-ops that can be counted (see `MetaCopyCounter`).

    :param weight_loader: used to load weights
    :param args: bound arguments to weight loader
    :return: number of elements loaded by the weight loader, the return value of the
        weight loader
    """
    assert args.arguments["param"].device.type == "meta"
    with MetaCopyCounter() as counter:
        return_value = weight_loader(*args.args, **args.kwargs)
    return counter.copied_numel, return_value
|
||||
50
vllm/model_executor/model_loader/reload/sanitize.py
Normal file
50
vllm/model_executor/model_loader/reload/sanitize.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["sanitize_layer_refs", "restore_layer_refs"]
|
||||
|
||||
|
||||
# Placeholder object that stands in for a layer inside sanitized bound methods
layer_ref_sentinel = object()


def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    """
    Removes references to layer held by tensor attributes. Specifically, removes
    the `__self__` attribute of weight loader methods attached to the tensor by
    re-binding them to `layer_ref_sentinel`.

    Used by `capture_layer_to_meta` to avoid circular references to layers in
    `LAYERWISE_INFO`, leading to modules never being cleaned up. Without
    sanitation, tensors will reference layers, and the WeakKeyDictionary will
    never evict entries, even when the model is deleted.

    :param tensor: tensor to be sanitized
    :param layer: layer whose references should be removed
    :return: sanitized tensor
    """
    for attr_name, attr_value in tensor.__dict__.items():
        bound_to_layer = (
            isinstance(attr_value, MethodType) and attr_value.__self__ is layer
        )
        if bound_to_layer:
            # Re-bind the underlying function to the sentinel instead of the layer
            tensor.__dict__[attr_name] = attr_value.__func__.__get__(
                layer_ref_sentinel
            )

    return tensor
|
||||
|
||||
|
||||
def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    """
    Restores references to layer held by tensor attributes.

    Used by `restore_layer_on_meta` to add back layer references, allowing for proper
    weight loading. Inverse of `sanitize_layer_refs`: methods bound to the
    sentinel are re-bound to `layer`.

    :param tensor: tensor whose sanitized references should be restored
    :param layer: layer to re-bind the sanitized methods to
    :return: restored tensor
    """
    for key, value in tensor.__dict__.items():
        if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
            tensor.__dict__[key] = value.__func__.__get__(layer)

    return tensor
|
||||
58
vllm/model_executor/model_loader/reload/torchao_decorator.py
Normal file
58
vllm/model_executor/model_loader/reload/torchao_decorator.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
from .layerwise import (
|
||||
finalize_layerwise_reload,
|
||||
initialize_layerwise_reload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
|
||||
__all__ = ["set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights"]
|
||||
|
||||
|
||||
def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
    """
    Mark `model` for torchao high-precision weight reloading.

    Sets the attributes read by `support_quantized_model_reload_from_hp_weights`
    so that subsequent `load_weights` calls go through the layerwise reload path.

    :param model: model to flag for reloading
    :param model_config: config later passed to `finalize_layerwise_reload`
    """
    model._do_torchao_reload = True
    model._model_config = model_config
|
||||
|
||||
|
||||
def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
    """
    Decorator for `AutoWeightsLoader.load_weights` that supports reloading
    high-precision (bfloat16/float16/float32) weights into an already quantized
    model. The weights are restored to high precision and then online-quantized.

    Only applies to torchao quantized models. Assumes that all model weights are
    loaded within a single weights iterator (cannot perform batched updates).
    """

    @wraps(original_load_weights)
    def patched_model_load_weights(
        self: "AutoWeightsLoader",
        weights: Iterable[tuple[str, torch.Tensor]],
        *args,
        **kwargs,
    ):
        target_model = self.module

        # Fast path: reloading not requested, defer to the original loader
        if not getattr(target_model, "_do_torchao_reload", False):
            return original_load_weights(self, weights, *args, **kwargs)

        # Wrap the load in layerwise reload setup/teardown
        initialize_layerwise_reload(target_model)
        result = original_load_weights(self, weights, *args, **kwargs)
        finalize_layerwise_reload(target_model, target_model._model_config)

        return result

    return patched_model_load_weights
|
||||
33
vllm/model_executor/model_loader/reload/types.py
Normal file
33
vllm/model_executor/model_loader/reload/types.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from inspect import BoundArguments
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["LayerTensors", "LayerReloadingInfo"]
|
||||
|
||||
# encodes both parameters and buffers separately
|
||||
LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerReloadingInfo:
|
||||
# model format (meta), populated by `record_metadata_for_reloading`
|
||||
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
|
||||
|
||||
# kernel format (device)
|
||||
kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
|
||||
|
||||
# track how many restored elements are ready for loading
|
||||
load_numel: int = 0
|
||||
load_numel_total: int | None = None
|
||||
|
||||
# stores arguments and tensors ready for loading
|
||||
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
|
||||
|
||||
def reset(self):
|
||||
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
|
||||
|
||||
def can_process(self) -> bool:
|
||||
return self.load_numel_total is not None
|
||||
31
vllm/model_executor/model_loader/reload/utils.py
Normal file
31
vllm/model_executor/model_loader/reload/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from .types import LayerTensors
|
||||
|
||||
__all__ = [
|
||||
"get_layer_tensors",
|
||||
"get_layer_params_buffers",
|
||||
"get_layer_size",
|
||||
]
|
||||
|
||||
|
||||
def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Get all parameters and buffers from a module as a single dict."""
    params, buffers = get_layer_params_buffers(layer)
    combined = dict(params)
    combined.update(buffers)
    return combined
|
||||
|
||||
|
||||
def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
    """Get all (non-None) parameters and buffers of a module as a tuple of dicts."""
    # `_parameters`/`_buffers` hold only this module's direct tensors (not
    # children's); entries may be None for registered-but-unset slots, which
    # are skipped.
    params = {k: v for k, v in layer._parameters.items() if v is not None}
    buffers = {k: v for k, v in layer._buffers.items() if v is not None}
    return params, buffers
|
||||
|
||||
|
||||
def get_layer_size(layer: torch.nn.Module) -> int:
    """Calculate the total number of elements across all tensors in a layer."""
    layer_tensors = get_layer_tensors(layer).values()
    return sum(t.numel() for t in layer_tensors)
|
||||
Reference in New Issue
Block a user