# Files
# bi_150-vllm/vllm/model_executor/model_loader/reload/layerwise.py
#
# 276 lines
# 10 KiB
# Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from weakref import WeakKeyDictionary
import torch
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .meta import (
capture_layer_to_meta,
get_numel_loaded,
materialize_layer,
restore_layer_on_meta,
)
from .types import LayerReloadingInfo
from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors
logger = init_logger(__name__)
# Public API of this module.
__all__ = [
    "get_layerwise_info",
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
    "finalize_layerwise_reload",
]

# Global dict storing information used for layerwise restoring, loading, and processing.
# For more information regarding what info is stored when, see `LayerReloadingInfo`
#
# Use a weak ref dictionary so that modules can be freed when the model is freed.
# Values are sanitized from references to the layer key in order to avoid circular refs
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
    WeakKeyDictionary()
)
def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
    """
    Fetch the reloading/layerwise-processing info associated with `layer`.

    On first access for a given layer, a fresh `LayerReloadingInfo` is
    created and registered in the global `LAYERWISE_INFO` weak-key dict.
    """
    try:
        return LAYERWISE_INFO[layer]
    except KeyError:
        # First time this layer is seen: create and register a new entry
        info = LayerReloadingInfo()
        LAYERWISE_INFO[layer] = info
        return info
def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Record layer metadata needed for later reloading.

    Stores parameter and buffer metadata as meta tensors for restoration.
    Must be called before `initialize_layerwise_reload`.
    """
    # Capture a meta-device snapshot of every submodule's params/buffers
    for module in model.modules():
        get_layerwise_info(module).restore_metadata = capture_layer_to_meta(module)
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to original tensor storage
    """
    # disable torchao reloading to avoid infinite recursion;
    # the prior flag is saved so `finalize_layerwise_reload` can restore it
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False
    for layer in model.modules():
        info = get_layerwise_info(layer)
        # Skip if the layer has already been initialized
        if info.can_process():
            continue
        # Save current tensors for later copying
        # (must happen before the layer is swapped onto the meta device below)
        info.kernel_tensors = get_layer_params_buffers(layer)
        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)
        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)
        # Wrap each parameter's weight loader
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            # Only wrap loaders that are not already layerwise wrappers
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
    """
    Create a wrapped weight loader that defers processing.

    Instead of loading immediately, the returned loader caches each call's
    bound arguments on the layer's `LayerReloadingInfo` and tracks how many
    elements have been loaded. Once the running total reaches the layer's
    total size, `_layerwise_process` materializes the layer, replays the
    cached loads, and runs quantization processing.
    """
    info = get_layerwise_info(layer)
    param = getattr(layer, param_name)
    original_loader = _get_original_loader(param)
    # Captured once so each call's args can be bound/normalized uniformly
    loader_signature = inspect.signature(original_loader)

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
        if not info.can_process():
            # Unfortunately, some qconfigs are set up to load the same weight
            # multiple times. For example, CT_WNA16 loads `weight_shape` for
            # each of the qkv partitions. This results in layers loading extra
            # weights (beyond load_numel_total) after it's already processed.
            #
            # Best solution is to ensure that `load_numel_total` reflects the
            # actual number of weights loaded, either by modifying qconfigs to
            # create as many weights as loaded (see padding issue as well)
            # or maybe capturing how many weights are loaded on first pass
            #
            # For now, `load_numel_total` is still safe to use as long as
            # there's no way to reach `load_numel_total` without loading all
            # necessary weights. `weight_shape` is very small, so this is safe.
            # see Limitations(4)
            logger.debug("%s: Excessive loading", layer.__class__.__name__)
            return
        # Bind and normalize arguments
        bound_args = loader_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        # Cache loaded weights, track loading progress
        info.loaded_weights.append((param_name, bound_args))
        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
        info.load_numel += num_loaded
        logger.debug(
            "%s: %d / %d",
            layer.__class__.__name__,
            info.load_numel,
            info.load_numel_total,
        )
        # Process and copy when all weights are loaded
        # (Attention/MLA layers are instead handled in `finalize_layerwise_reload`)
        if info.load_numel >= info.load_numel_total and not isinstance(  # type: ignore[operator]
            layer, (Attention, MLAAttention)
        ):
            _layerwise_process(layer, info)
        return ret

    return online_process_loader
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be called after `initialize_layerwise_reload` in
    order to unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all other layers
    """
    # Restore the torchao reload flag saved by `initialize_layerwise_reload`
    model._do_torchao_reload = model._original_do_torchao_reload
    for layer in model.modules():
        info = get_layerwise_info(layer)
        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )
            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)
        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)
        # Process non-attention layers which did not load all elements. This can happen
        # if the created weight has extra padding elements which are not loaded
        # Having too many of these delayed layers can lead to excess memory usage
        # see Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)
        info.reset()
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
    Finalize layer loading after all weights have been cached.

    This function:
    1. Materializes the layer onto the target device
    2. Loads all cached weights
    3. Runs quantization processing if applicable
    4. Copies processed values back to original tensor storage
    """
    # Materialize layer tensors onto device
    materialize_layer(layer)
    # Reset FP8 online quantization flag so process_weights_after_loading
    # will run again during reload
    if hasattr(layer, "_already_called_process_weights_after_loading"):
        delattr(layer, "_already_called_process_weights_after_loading")
    # Unwrap layerwise loading wrappers so the replayed loads below run the
    # original loaders directly instead of being cached again
    for param in get_layer_tensors(layer).values():
        param.weight_loader = _get_original_loader(param)
    # Load all cached weights into materialized layer (using original loaders)
    for name, args in info.loaded_weights:
        param = getattr(layer, name)
        # Rebind the freshly-materialized parameter before replaying the call
        args.arguments["param"] = param
        param.weight_loader(*args.args, **args.kwargs)
    # Process weights (quantization, repacking, etc.)
    # Attention/MLA are processed in `finalize_layerwise_reload`
    quant_method = getattr(layer, "quant_method", None)
    if isinstance(quant_method, QuantizeMethodBase):
        quant_method.process_weights_after_loading(layer)
    # Copy processed values into original tensor storage (preserves cudagraph refs)
    # this code is a no-op if not reloading (because kernel tensors is empty)
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        param.data.copy_(getattr(layer, name))
    for name, buffer in buffers.items():
        buffer.data.copy_(getattr(layer, name))
    _place_kernel_tensors(layer, info)
    info.reset()
    logger.debug("%s: Processed", layer.__class__.__name__)
def _get_original_loader(tensor: torch.Tensor) -> Callable:
    """Return the underlying weight loader, stripped of layerwise wrappers."""

    def _unwrap(fn: Callable) -> Callable:
        # Peel one `online_process_loader` wrapper at a time
        if fn.__name__ == "online_process_loader":
            return _unwrap(fn.__wrapped__)  # type: ignore[union-attr]
        return fn

    return _unwrap(_get_weight_loader(tensor))
def _get_weight_loader(tensor: torch.Tensor):
return getattr(tensor, "weight_loader", default_weight_loader)
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Replace the layer's current tensors with the saved kernel tensors."""
    # Snapshot the names first, then detach every current param/buffer
    for tensor_name in list(get_layer_tensors(layer)):
        delattr(layer, tensor_name)
    saved_params, saved_buffers = info.kernel_tensors
    # Re-register the original tensors so external references stay valid
    for param_name, param in saved_params.items():
        layer.register_parameter(param_name, param)
    for buffer_name, buffer in saved_buffers.items():
        layer.register_buffer(buffer_name, buffer)