Revert "Add simple CPU offloading support" (#2252)

We'll re-add the commit to correctly acknowledge Kaichao's authorship.
Ying Sheng authored on 2024-11-28 23:36:55 -08:00, committed by GitHub
parent 4f2ee48ed1
commit 4057ea82c9
9 changed files with 29 additions and 173 deletions


@@ -32,7 +32,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import psutil
@@ -45,7 +45,6 @@ from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
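
The `functional_call` import deleted in the hunk above existed only to serve the offloading code removed below: `torch.func.functional_call` runs a module's forward pass against a substitute parameter/buffer dict without mutating the module itself. A minimal self-contained sketch of that behavior (the `Linear` layer and tensors here are illustrative, not part of this commit):

    import torch
    from torch.func import functional_call

    lin = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)

    # A replacement state dict; in the reverted code below this is the set of
    # CPU-resident weights copied back to the GPU for the duration of one call.
    override = {k: v.clone() for k, v in lin.state_dict().items()}

    # Runs lin's forward with `override` instead of its own parameters.
    out = functional_call(lin, override, args=(x,))
    print(out.shape)  # torch.Size([1, 2])
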
@@ -193,94 +192,6 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
-def is_pin_memory_available() -> bool:
-    return torch.cuda.is_available()
-
-
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-
-
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-
-
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    device = next(module.parameters()).device
-
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module, device_state, args=args, kwargs=kwargs)
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
-class LayerFn(Protocol):
-
-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
-
-
-def make_layers(
-    num_hidden_layers: int,
-    layer_fn: LayerFn,
-    prefix: str = "",
-) -> Tuple[int, int, torch.nn.ModuleList]:
-    """Make a list of layers with the given layer function"""
-    modules = torch.nn.ModuleList(
-        [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
-            for idx in range(num_hidden_layers)
-        ]
-    )
-    return modules
-
-
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
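
For context, a sketch of how the reverted helpers fit together: a layer is built on the GPU, `maybe_offload_to_cpu` moves its parameters into pinned host memory until the global byte budget is spent, and the swapped-in `forward` copies them back per call via `functional_call` (restoring the original `forward` first so the inner call does not recurse). The driver below is hypothetical; the budget, layer shape, and `build_layer` are illustrative assumptions, while `set_cpu_offload_max_bytes` and `make_layers` are the functions removed above:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical builder matching make_layers' call site,
    # which passes the layer index via the `idx` keyword.
    def build_layer(idx: int, prefix: str) -> torch.nn.Module:
        return torch.nn.Linear(1024, 1024).to(device)

    # Reset the running counter and allow up to 1 GiB of parameters
    # to live in (pinned) CPU memory.
    set_cpu_offload_max_bytes(1 << 30)

    # Each ~4 MiB Linear layer is offloaded until the budget is exhausted;
    # on a CPU-only machine maybe_offload_to_cpu is a no-op.
    layers = make_layers(num_hidden_layers=4, layer_fn=build_layer, prefix="model.layers")
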