Add minimal vLLM 0.16.1 build repo for BI-V150
vllm/model_executor/model_loader/reload/meta.py | 146 (new file)
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from .sanitize import restore_layer_refs, sanitize_layer_refs
from .types import LayerReloadingInfo, LayerTensors
from .utils import get_layer_params_buffers, get_layer_tensors

__all__ = [
    "to_meta_tensor",
    "materialize_meta_tensor",
    "capture_layer_to_meta",
    "restore_layer_on_meta",
    "materialize_layer",
    "get_numel_loaded",
]

SKIP_MODULES: set[str] = {"HadamardTransform"}

SKIP_TENSORS: set[str] = {
    "_expert_map",
    "expert_mask",
    "expert_global_to_physical",
    "expert_physical_to_global",
    "expert_local_to_global",
}

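# Editorial note (not part of this commit): the SKIP_TENSORS names look like
# expert-parallel routing/mapping tensors that are rebuilt at runtime rather
# than reloaded from a checkpoint, which would explain why they (and the
# HadamardTransform module) are excluded from the capture/restore cycle below.
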
def to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Convert a tensor to a meta tensor while preserving class and attributes."""
    meta_tensor = tensor.data.to("meta")
    meta_tensor.__class__ = tensor.__class__
    meta_tensor.__dict__ = tensor.__dict__.copy()
    return meta_tensor

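# Illustrative sketch (editorial, not part of this commit): because both
# `__class__` and `__dict__` are copied, a `torch.nn.Parameter` with extra
# attributes (e.g. a vLLM `weight_loader`) keeps them across the conversion:
#
#     param = torch.nn.Parameter(torch.randn(4, 4))
#     param.weight_loader = lambda *a: None  # stand-in attribute
#     meta = to_meta_tensor(param)
#     assert isinstance(meta, torch.nn.Parameter)
#     assert meta.device.type == "meta" and hasattr(meta, "weight_loader")
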
def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
    """
    Materialize a meta tensor into an actual tensor on the current device.
    Should be called within the torch device context for the given rank.
    """
    tensor = torch.empty_strided(
        size=tuple(meta_tensor.size()),
        stride=tuple(meta_tensor.stride()),
        dtype=meta_tensor.dtype,
        requires_grad=False,
    )
    tensor.__class__ = meta_tensor.__class__
    tensor.__dict__ = meta_tensor.__dict__.copy()
    return tensor

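# Illustrative sketch (editorial, not part of this commit): `torch.device`
# works as a context manager in PyTorch 2.x, so materialization lands on
# whatever device is active. Note the result is allocated but uninitialized:
#
#     meta = to_meta_tensor(torch.randn(2, 3))
#     with torch.device("cpu"):
#         real = materialize_meta_tensor(meta)
#     assert real.device.type == "cpu" and real.shape == (2, 3)
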
def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
    """Capture a layer's parameters and buffers as sanitized meta tensors."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return ({}, {})

    params, buffers = get_layer_params_buffers(layer)
    return (
        {
            name: sanitize_layer_refs(to_meta_tensor(param), layer)
            for name, param in params.items()
            if name not in SKIP_TENSORS
        },
        {
            name: sanitize_layer_refs(to_meta_tensor(buffer), layer)
            for name, buffer in buffers.items()
            if name not in SKIP_TENSORS
        },
    )

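# Illustrative sketch (editorial, not part of this commit), assuming
# get_layer_params_buffers returns ({name: param}, {name: buffer}) dicts for
# the layer's own tensors:
#
#     layer = torch.nn.Linear(4, 4)
#     meta_params, meta_buffers = capture_layer_to_meta(layer)
#     assert meta_params["weight"].device.type == "meta"
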
def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Restore a layer to model format with tensors on the meta device."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return

    for name in get_layer_tensors(layer):
        if name not in SKIP_TENSORS:
            delattr(layer, name)

    restore_params, restore_buffers = info.restore_metadata
    for name, param in restore_params.items():
        if name not in SKIP_TENSORS:
            param = restore_layer_refs(param, layer)
            layer.register_parameter(name, param)

    for name, buffer in restore_buffers.items():
        if name not in SKIP_TENSORS:
            buffer = restore_layer_refs(buffer, layer)
            layer.register_buffer(name, buffer)

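# Editorial note (not part of this commit): tensors are removed with delattr
# and then re-registered via register_parameter/register_buffer so that each
# restored tensor lands back in the correct category, keeping
# named_parameters() and named_buffers() consistent after a reload.
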
def materialize_layer(layer: torch.nn.Module) -> None:
    """Materialize all meta tensors in a layer to actual tensors."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return

    for name, tensor in get_layer_tensors(layer).items():
        if name not in SKIP_TENSORS:
            setattr(layer, name, materialize_meta_tensor(tensor))

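# Illustrative reload cycle (editorial sketch, not part of this commit),
# assuming info.restore_metadata holds the (params, buffers) tuple produced
# by capture_layer_to_meta:
#
#     restore_layer_on_meta(layer, info)  # swap tensors for meta versions
#     materialize_layer(layer)            # allocate fresh, uninitialized storage
#     # ...then refill the storage, e.g. via the layer's weight loaders
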
class MetaCopyCounter(TorchDispatchMode):
    """
    Tracks the total number of elements modified with `copy_`.

    Useful for keeping track of weight loading, where the underlying weights
    can be arbitrarily transformed (such as with `narrow`) before `copy_` is
    called.

    Note: assumes that `copy_` kwargs are not used.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
            assert args[0].numel() == args[1].numel()
            self.copied_numel += args[0].numel()

        return func(*args, **kwargs)

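# Illustrative sketch (editorial, not part of this commit), assuming copy_
# onto a meta tensor dispatches normally (it is a data no-op on meta). The
# two narrowed copies below are still counted in full:
#
#     dst = torch.empty(8, device="meta")
#     with MetaCopyCounter() as counter:
#         dst.narrow(0, 0, 4).copy_(torch.ones(4))
#         dst.narrow(0, 4, 4).copy_(torch.ones(4))
#     assert counter.copied_numel == 8
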
def get_numel_loaded(
    weight_loader: Callable, args: inspect.BoundArguments
) -> tuple[int, object]:
    """
    Determine how many elements would be loaded by a weight loader call.

    :param weight_loader: used to load weights
    :param args: bound arguments to the weight loader
    :return: the number of elements loaded by the weight loader, and the
        return value of the weight loader
    """
    assert args.arguments["param"].device.type == "meta"
    with MetaCopyCounter() as counter:
        return_value = weight_loader(*args.args, **args.kwargs)
    return counter.copied_numel, return_value
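
# Illustrative sketch (editorial, not part of this commit): a stand-in loader
# in the style of vLLM's default_weight_loader; the bound signature must
# include a parameter named "param" on the meta device:
#
#     def simple_loader(param, loaded_weight):
#         param.data.copy_(loaded_weight)
#
#     param = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
#     bound = inspect.signature(simple_loader).bind(param, torch.randn(4, 4))
#     numel, _ = get_numel_loaded(simple_loader, bound)
#     assert numel == 16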