42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import torch
|
|
|
|
|
|
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
    """Make ``dst`` take on the shape, stride and values of ``src`` in place.

    ``dst`` keeps its Python identity, so every existing reference to it
    (e.g. a module's parameter registration) stays valid; only its view
    metadata and, when needed, its contents change.

    * If ``dst`` and ``src`` view the same underlying storage, ``dst`` is
      re-pointed at exactly the region ``src`` views (shape, stride and
      storage offset) and no data is copied.
    * Otherwise ``dst`` is restrided to ``src``'s shape and stride over its
      own storage (keeping its current storage offset) and ``src``'s values
      are copied into it.

    Args:
        dst: Tensor to update in place.
        src: Tensor supplying the new shape, stride and values. Must have the
            same dtype as ``dst``.

    Raises:
        AssertionError: if ``dst`` and ``src`` have different dtypes.
        RuntimeError: (from torch) presumably when ``dst``'s storage is too
            small for ``src``'s shape/stride — confirm against torch docs.
    """
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    shares_storage = (
        dst.device == src.device
        and dst.untyped_storage().data_ptr() == src.untyped_storage().data_ptr()
    )
    if shares_storage:
        # Same underlying storage: adopt src's view directly, including its
        # storage offset. Restriding without the offset and then calling
        # copy_ would read from and write to overlapping memory.
        dst.as_strided_(src.shape, src.stride(), src.storage_offset())
        return

    # Distinct storage: reshape dst over its own storage, then move the data.
    dst.as_strided_(src.shape, src.stride())
    dst.copy_(src)
|
|
|
|
|
|
# Newly generated tensors need to replace existing tensors that are
|
|
# already registered as parameters by vLLM (and won't be freed)
|
|
def replace_parameter(
|
|
mod: torch.nn.Module, name: str, new: torch.Tensor | torch.nn.Parameter
|
|
) -> None:
|
|
old = getattr(mod, name)
|
|
if (
|
|
type(old) is type(new)
|
|
and old.dtype == new.dtype
|
|
and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
|
|
):
|
|
# If we can just update in-place to avoid re-registering
|
|
# can be faster if the underlying storage is the same
|
|
update_tensor_inplace(old, new)
|
|
else:
|
|
# Fallback re-register parameter, convert to Parameter if necessary
|
|
# this not only ensures we don't register a tensor as a parameter, but
|
|
# also ensures that all parameter subclasses get re-registered as
|
|
# parameters for `torch.compile` compatibility
|
|
if not isinstance(new, torch.nn.Parameter):
|
|
new = torch.nn.Parameter(new, requires_grad=False)
|
|
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|