Refactor weight offloading logic (#8521)

Author: fzyzcjy
Date: 2025-08-21 18:48:13 +08:00
Committed by: GitHub
Parent: de4990a5b2
Commit: 55d336cb08
3 changed files with 141 additions and 74 deletions


@@ -438,72 +438,6 @@ def is_pin_memory_available() -> bool:
     return torch.cuda.is_available()
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    if (params := next(module.parameters(), None)) is None:
-        return module
-    device = params.device
-    if device == torch.device("cpu"):
-        return module
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-    if offloaded_parameters:
-        original_forward = module.forward
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module, device_state, args=args, kwargs=kwargs)
-            module.forward = forward
-            return output
-        module.forward = forward
-    return module
 class LayerFn(Protocol):
     def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
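
For reference, the removed helpers formed a small global-budget API: set_cpu_offload_max_bytes() fixed how many bytes of weights could be moved to CPU, and maybe_offload_to_cpu() was applied per layer until that budget was exhausted. A minimal usage sketch of the pre-refactor API; the budget value, build_decoder_layer, and num_layers are illustrative names, not from this commit:

# Hypothetical caller of the removed API.
set_cpu_offload_max_bytes(2 * 1024**3)  # allow ~2 GiB of weights on CPU

layers = torch.nn.ModuleList(
    maybe_offload_to_cpu(build_decoder_layer(i)) for i in range(num_layers)
)
# Parameters are copied to (pinned) CPU memory until the byte budget is hit;
# the patched `forward` then reloads them onto the original device on each
# call via `functional_call`.
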
@@ -516,11 +450,13 @@ def make_layers(
     pp_size: Optional[int] = None,
     prefix: str = "",
     return_tuple: bool = False,
+    offloader_kwargs: Dict[str, Any] = {},
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function"""
     # circular imports
     from sglang.srt.distributed import get_pp_indices
     from sglang.srt.layers.utils import PPMissingLayer
+    from sglang.srt.offloader import get_offloader
     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = (
@@ -534,10 +470,13 @@ def make_layers(
     )
     modules = torch.nn.ModuleList(
         [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
-        + [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
-            for idx in range(start_layer, end_layer)
-        ]
+        + get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(start_layer, end_layer)
+            ),
+            **offloader_kwargs,
+        )
         + [
             PPMissingLayer(return_tuple=return_tuple)
             for _ in range(end_layer, num_hidden_layers)
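
The replacement delegates to a pluggable offloader whose implementation is not part of this excerpt. The call site above only requires that get_offloader() return an object whose wrap_modules() consumes the layer generator plus offloader_kwargs and returns a list that can be concatenated with the surrounding PPMissingLayer lists. A no-op sketch of that contract; the class name and module-level variable here are assumptions, not the code added by this commit:

# Sketch only: the real sglang.srt.offloader module is not shown in this diff.
from typing import Any, Generator, List

import torch


class NoopOffloader:
    """Assumed default: keep every layer wherever layer_fn created it."""

    def wrap_modules(
        self, modules: Generator[torch.nn.Module, None, None], **kwargs: Any
    ) -> List[torch.nn.Module]:
        # Must return a list, because make_layers concatenates the result
        # with the surrounding [PPMissingLayer(...)] lists.
        return list(modules)


_OFFLOADER = NoopOffloader()


def get_offloader() -> NoopOffloader:
    return _OFFLOADER

A real offloader can substitute each yielded layer with an offloaded wrapper before returning, and offloader_kwargs lets individual models forward extra options to it without further changes to the make_layers signature.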