Add minimal vLLM 0.16.1 build repo for BI-V150

This commit is contained in:
2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layerwise weight reloading utilities for vLLM.
This module provides functionality to reload model weights layer-by-layer,
which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
3. Tied parameters will only reflect processing from one of the parent layers (for
example, only processing from embed_tokens will have an effect)
4. This design assumes that the number of weights loaded from disk is the same as the
number of weights created at model init time. This is not true for quant methods
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
cases are non-issues for today's quant methods, but future quantizations may cause
reloading to fail
"""
__all__ = [
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_reload",
"set_torchao_reload_attrs",
"support_quantized_model_reload_from_hp_weights",
]
from .layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
record_metadata_for_reloading,
)
from .torchao_decorator import (
set_torchao_reload_attrs,
support_quantized_model_reload_from_hp_weights,
)

View File

@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from weakref import WeakKeyDictionary
import torch
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .meta import (
capture_layer_to_meta,
get_numel_loaded,
materialize_layer,
restore_layer_on_meta,
)
from .types import LayerReloadingInfo
from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors
logger = init_logger(__name__)
__all__ = [
"get_layerwise_info",
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_reload",
]
# Global dict storing information used for layerwise restoring, loading, and processing.
# For more information regarding what info is stored when, see `LayerReloadingInfo`
#
# Use a weak ref dictionary so that modules can be freed when the model is freed.
# Values are sanitized from references to the layer key in order to avoid circular refs
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
)
def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
    """
    Fetch the reloading info entry for `layer`, creating a fresh
    `LayerReloadingInfo` on first access.
    """
    info = LAYERWISE_INFO.get(layer)
    if info is None:
        # First time this layer is seen: start with an empty info record
        info = LayerReloadingInfo()
        LAYERWISE_INFO[layer] = info
    return info
def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Snapshot per-layer metadata needed for later reloading.

    Captures each layer's parameters and buffers as meta tensors so the
    layer can later be restored to its pre-processing format.
    Must be called before `initialize_layerwise_reload`.
    """
    for module in model.modules():
        get_layerwise_info(module).restore_metadata = capture_layer_to_meta(module)
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to original tensor storage

    :param model: model whose layers should be prepared for reloading
    """
    # disable torchao reloading to avoid infinite recursion
    # (the patched `load_weights` would otherwise re-enter this machinery)
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False
    for layer in model.modules():
        info = get_layerwise_info(layer)
        # Skip if the layer has already been initialized
        # (`can_process` is true once `load_numel_total` has been set below;
        # it is cleared again by `info.reset()` after processing)
        if info.can_process():
            continue
        # Save current tensors for later copying
        info.kernel_tensors = get_layer_params_buffers(layer)
        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)
        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)
        # Wrap each parameter's weight loader
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
    """
    Create a wrapped weight loader that defers processing.

    The wrapper caches the bound arguments of each load call and advances the
    layer's loaded-element count via `get_numel_loaded`. Once the count reaches
    `load_numel_total` (and the layer is not Attention/MLA), it triggers
    `_layerwise_process` to materialize, load, process, and copy back.

    :param layer: layer that owns the parameter being wrapped
    :param param_name: attribute name of the parameter on `layer`
    :return: the wrapping loader to assign as the tensor's `weight_loader`
    """
    info = get_layerwise_info(layer)
    param = getattr(layer, param_name)
    # Wrap the fully unwrapped loader so cached args can be replayed against it
    original_loader = _get_original_loader(param)
    loader_signature = inspect.signature(original_loader)

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
        if not info.can_process():
            # Unfortunately, some qconfigs are set up to load the same weight
            # multiple times. For example, CT_WNA16 loads `weight_shape` for
            # each of the qkv partitions. This results in layers loading extra
            # weights (beyond load_numel_total) after it's already processed.
            #
            # Best solution is to ensure that `load_numel_total` reflects the
            # actual number of weights loaded, either by modifying qconfigs to
            # create as many weights as loaded (see padding issue as well)
            # or maybe capturing how many weights are loaded on first pass
            #
            # For now, `load_numel_total` is still safe to use as long as
            # there's no way to reach `load_numel_total` without loading all
            # necessary weights. `weight_shape` is very small, so this is safe.
            # see Limitations(4)
            logger.debug("%s: Excessive loading", layer.__class__.__name__)
            return
        # Bind and normalize arguments so they can be cached and replayed later
        bound_args = loader_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        # Cache loaded weights, track loading progress
        info.loaded_weights.append((param_name, bound_args))
        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
        info.load_numel += num_loaded
        logger.debug(
            "%s: %d / %d",
            layer.__class__.__name__,
            info.load_numel,
            info.load_numel_total,
        )
        # Process and copy when all weights are loaded
        # (Attention/MLA layers are instead handled in `finalize_layerwise_reload`)
        if info.load_numel >= info.load_numel_total and not isinstance(  # type: ignore[operator]
            layer, (Attention, MLAAttention)
        ):
            _layerwise_process(layer, info)
        return ret

    return online_process_loader
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be applied after `initialize_layerwise_reload` in
    order to unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all
    other layers.

    :param model: model that has finished loading weights
    :param model_config: provides the dtype used when processing attention
        layer weights
    :raises NotImplementedError: if weights were loaded into an Attention/MLA
        layer (Q/K/V scale reloading is unsupported, see Limitations(2))
    """
    # Restore the torchao reload flag saved by `initialize_layerwise_reload`
    model._do_torchao_reload = model._original_do_torchao_reload
    for layer in model.modules():
        info = get_layerwise_info(layer)
        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )
            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)
        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)
        # Process non-attention layers which did not load all elements. This can happen
        # if the created weight has extra padding elements which are not loaded
        # Having too many of these delayed layers can lead to excess memory usage
        # see Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)
        info.reset()
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
    Finalize layer loading after all weights have been cached.

    This function:
    1. Materializes the layer onto the target device
    2. Loads all cached weights
    3. Runs quantization processing if applicable
    4. Copies processed values back to original tensor storage

    :param layer: layer whose cached weights should be processed
    :param info: reloading info holding the cached weights and kernel tensors
    """
    # Materialize layer tensors onto device
    materialize_layer(layer)
    # Reset FP8 online quantization flag so process_weights_after_loading
    # will run again during reload
    if hasattr(layer, "_already_called_process_weights_after_loading"):
        delattr(layer, "_already_called_process_weights_after_loading")
    # Unwrap layerwise loading wrappers
    for param in get_layer_tensors(layer).values():
        param.weight_loader = _get_original_loader(param)
    # Load all cached weights into materialized layer (using original loaders)
    # The cached "param" argument still references the old meta tensor, so
    # rebind it to the freshly materialized parameter before replaying
    for name, args in info.loaded_weights:
        param = getattr(layer, name)
        args.arguments["param"] = param
        param.weight_loader(*args.args, **args.kwargs)
    # Process weights (quantization, repacking, etc.)
    # Attention/MLA are processed in `finalize_layerwise_reload`
    quant_method = getattr(layer, "quant_method", None)
    if isinstance(quant_method, QuantizeMethodBase):
        quant_method.process_weights_after_loading(layer)
    # Copy processed values into original tensor storage (preserves cudagraph refs)
    # this code is a no-op if not reloading (because kernel tensors is empty)
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        param.data.copy_(getattr(layer, name))
    for name, buffer in buffers.items():
        buffer.data.copy_(getattr(layer, name))
    # Swap the kernel-format tensors back onto the layer and clear load state
    _place_kernel_tensors(layer, info)
    info.reset()
    logger.debug("%s: Processed", layer.__class__.__name__)
def _get_original_loader(tensor: torch.Tensor) -> Callable:
    """Return the tensor's weight loader with every layerwise wrapper stripped."""
    unwrapped = _get_weight_loader(tensor)
    # Shared tensors may have been wrapped multiple times; peel each layer
    while unwrapped.__name__ == "online_process_loader":
        unwrapped = unwrapped.__wrapped__  # type: ignore[union-attr]
    return unwrapped
def _get_weight_loader(tensor: torch.Tensor):
    """Return the tensor's attached weight loader, falling back to the default."""
    try:
        return tensor.weight_loader
    except AttributeError:
        return default_weight_loader
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Swap the layer's current tensors for the saved kernel-format tensors."""
    # Drop whatever tensors currently live on the layer
    for name in get_layer_tensors(layer):
        delattr(layer, name)
    # Re-register the saved kernel-format parameters and buffers
    parameters, buffers = info.kernel_tensors
    for name, parameter in parameters.items():
        layer.register_parameter(name, parameter)
    for name, buf in buffers.items():
        layer.register_buffer(name, buf)

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from .sanitize import restore_layer_refs, sanitize_layer_refs
from .types import LayerReloadingInfo, LayerTensors
from .utils import get_layer_params_buffers, get_layer_tensors
__all__ = [
"to_meta_tensor",
"materialize_meta_tensor",
"capture_layer_to_meta",
"restore_layer_on_meta",
"materialize_layer",
"get_numel_loaded",
]
SKIP_MODULES: set[str] = {"HadamardTransform"}
SKIP_TENSORS: set[str] = {
"_expert_map",
"expert_mask",
"expert_global_to_physical",
"expert_physical_to_global",
"expert_local_to_global",
}
def to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a meta-device copy of `tensor`, keeping its class and attributes."""
    shadow = tensor.data.to("meta")
    # Preserve the tensor subclass (e.g. Parameter) and any attached attributes
    shadow.__class__ = tensor.__class__
    shadow.__dict__ = tensor.__dict__.copy()
    return shadow
def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
    """
    Allocate a real tensor on the current device matching a meta tensor.

    Should be called within the torch device context for the given rank.
    """
    # Match the meta tensor's exact layout, including any non-contiguous strides
    real = torch.empty_strided(
        size=tuple(meta_tensor.size()),
        stride=tuple(meta_tensor.stride()),
        dtype=meta_tensor.dtype,
        requires_grad=False,
    )
    # Carry over the tensor subclass and attached attributes
    real.__class__ = meta_tensor.__class__
    real.__dict__ = meta_tensor.__dict__.copy()
    return real
def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
    """Capture the layer's params and buffers as sanitized meta tensors."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return ({}, {})
    params, buffers = get_layer_params_buffers(layer)

    def snapshot(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Meta copies carry no storage; sanitation drops back-refs to `layer`
        return {
            name: sanitize_layer_refs(to_meta_tensor(tensor), layer)
            for name, tensor in tensors.items()
            if name not in SKIP_TENSORS
        }

    return (snapshot(params), snapshot(buffers))
def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Restore a layer to model format with tensors on the meta device"""
    if layer.__class__.__name__ in SKIP_MODULES:
        return
    # Remove the current tensors from the layer, except skipped ones
    for name in get_layer_tensors(layer):
        if name in SKIP_TENSORS:
            continue
        delattr(layer, name)
    # Re-register the captured model-format tensors with layer refs restored
    restore_params, restore_buffers = info.restore_metadata
    for name, param in restore_params.items():
        if name in SKIP_TENSORS:
            continue
        layer.register_parameter(name, restore_layer_refs(param, layer))
    for name, buffer in restore_buffers.items():
        if name in SKIP_TENSORS:
            continue
        layer.register_buffer(name, restore_layer_refs(buffer, layer))
def materialize_layer(layer: torch.nn.Module) -> None:
    """Replace every meta tensor in the layer with a real allocation."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return
    for name, meta in get_layer_tensors(layer).items():
        if name in SKIP_TENSORS:
            continue
        setattr(layer, name, materialize_meta_tensor(meta))
class MetaCopyCounter(TorchDispatchMode):
    """
    Dispatch mode that totals the number of elements written into meta
    tensors via `copy_`.

    Useful for tracking weight-loading progress where underlying weights can
    be arbitrarily transformed (such as with `narrow`) before calling copy.

    Note: Assumes that copy kwargs are not used.
    """

    def __init__(self):
        super().__init__()
        # Running total of elements copied into meta-device destinations
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        is_meta_copy = (
            func is torch.ops.aten.copy_.default and args[0].device.type == "meta"
        )
        if is_meta_copy:
            dst, src = args[0], args[1]
            assert dst.numel() == src.numel()
            self.copied_numel += dst.numel()
        return func(*args, **kwargs)
def get_numel_loaded(
    weight_loader: Callable, args: inspect.BoundArguments
) -> tuple[int, object]:
    """
    Run a weight loader against a meta parameter, counting elements loaded.

    :param weight_loader: used to load weights
    :param args: bound arguments to weight loader
    :return: number of elements copied by the weight loader, and the return
        value of the weight loader
    """
    # Counting is only meaningful for meta params, where copies store nothing
    assert args.arguments["param"].device.type == "meta"
    counter = MetaCopyCounter()
    with counter:
        result = weight_loader(*args.args, **args.kwargs)
    return counter.copied_numel, result

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import MethodType
import torch
__all__ = ["sanitize_layer_refs", "restore_layer_refs"]
layer_ref_sentinel = object()
def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    """
    Strip references to `layer` held by the tensor's attributes. Specifically,
    rebinds the `__self__` of weight loader methods attached to the tensor.

    Used by `capture_layer_to_meta` to avoid circular references to layers in
    `LAYERWISE_INFO`, leading to modules never being cleaned up. Without
    sanitation, tensors will reference layers, and the WeakKeyDictionary will
    never evict entries, even when the model is deleted.

    :param tensor: tensor to be sanitized
    :param layer: layer whose references should be removed
    :return: sanitized tensor
    """
    for attr_name, attr_value in tensor.__dict__.items():
        is_layer_method = (
            isinstance(attr_value, MethodType) and attr_value.__self__ is layer
        )
        if is_layer_method:
            # Rebind the underlying function onto the sentinel placeholder
            tensor.__dict__[attr_name] = attr_value.__func__.__get__(layer_ref_sentinel)
    return tensor
def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    """
    Re-attach references to `layer` that were removed by `sanitize_layer_refs`.

    Used by `restore_layer_on_meta` to add back layer references, allowing for
    proper weight loading.

    :param tensor: tensor whose layer references should be restored
    :param layer: layer to rebind sentinel-bound methods onto
    :return: restored tensor
    """
    for attr_name, attr_value in tensor.__dict__.items():
        was_sanitized = (
            isinstance(attr_value, MethodType)
            and attr_value.__self__ is layer_ref_sentinel
        )
        if was_sanitized:
            # Rebind the underlying function back onto the real layer
            tensor.__dict__[attr_name] = attr_value.__func__.__get__(layer)
    return tensor

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING
import torch
from vllm.config import ModelConfig
from .layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
if TYPE_CHECKING:
from vllm.model_executor.models.utils import AutoWeightsLoader
__all__ = ["set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights"]
def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
    """Flag the model so the patched `load_weights` performs a torchao reload."""
    # Stash the config for `finalize_layerwise_reload` to use at reload time
    model._model_config = model_config
    model._do_torchao_reload = True
def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
    """
    Decorator for the `AutoWeightsLoader.load_weights` method that supports
    reloading high precision (bfloat16/float16/float32) weights into an
    already quantized model. This restores the weights to high precision and
    then online-quantizes them.

    Only applies to torchao quantized models. Assumes that all model weights
    are loaded within a single weights iterator (cannot perform batched
    updates).
    """

    @wraps(original_load_weights)
    def patched_model_load_weights(
        self: "AutoWeightsLoader",
        weights: Iterable[tuple[str, torch.Tensor]],
        *args,
        **kwargs,
    ):
        model = self.module
        # Models not flagged by `set_torchao_reload_attrs` load normally
        if not getattr(model, "_do_torchao_reload", False):
            return original_load_weights(self, weights, *args, **kwargs)
        # Wrap loaders, load everything, then process and unwrap
        initialize_layerwise_reload(model)
        loaded = original_load_weights(self, weights, *args, **kwargs)
        finalize_layerwise_reload(model, model._model_config)
        return loaded

    return patched_model_load_weights

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from inspect import BoundArguments
import torch
__all__ = ["LayerTensors", "LayerReloadingInfo"]
# encodes both parameters and buffers separately
LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
@dataclass
class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
# kernel format (device)
kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
# track how many restored elements are ready for loading
load_numel: int = 0
load_numel_total: int | None = None
# stores arguments and tensors ready for loading
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
def can_process(self) -> bool:
return self.load_numel_total is not None

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from .types import LayerTensors
__all__ = [
"get_layer_tensors",
"get_layer_params_buffers",
"get_layer_size",
]
def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Get all parameters and buffers from a module as a single dict."""
    params, buffers = get_layer_params_buffers(layer)
    merged = dict(params)
    merged.update(buffers)
    return merged
def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
    """Get the module's non-None parameters and buffers as a tuple of dicts."""
    params = {n: p for n, p in layer._parameters.items() if p is not None}
    buffers = {n: b for n, b in layer._buffers.items() if b is not None}
    return params, buffers
def get_layer_size(layer: torch.nn.Module) -> int:
    """Total number of elements across all of the layer's tensors."""
    total = 0
    for tensor in get_layer_tensors(layer).values():
        total += tensor.numel()
    return total