update
This commit is contained in:
31
vllm/model_executor/model_loader/reload/utils.py
Normal file
31
vllm/model_executor/model_loader/reload/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from .types import LayerTensors
|
||||
|
||||
__all__ = [
|
||||
"get_layer_tensors",
|
||||
"get_layer_params_buffers",
|
||||
"get_layer_size",
|
||||
]
|
||||
|
||||
|
||||
def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
|
||||
"""Get all parameters and buffers from a module as a dict."""
|
||||
params, buffers = get_layer_params_buffers(layer)
|
||||
return params | buffers
|
||||
|
||||
|
||||
def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
|
||||
"""Get all parameters and buffers of a module as a tuple of dicts."""
|
||||
return (
|
||||
{name: param for name, param in layer._parameters.items() if param is not None},
|
||||
{name: buffer for name, buffer in layer._buffers.items() if buffer is not None},
|
||||
)
|
||||
|
||||
|
||||
def get_layer_size(layer: torch.nn.Module) -> int:
|
||||
"""Calculate total number of elements across all tensors in a layer."""
|
||||
return sum(tensor.numel() for tensor in get_layer_tensors(layer).values())
|
||||
Reference in New Issue
Block a user