update

vllm/model_executor/offloader/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Model parameter offloading infrastructure."""

from vllm.model_executor.offloader.base import (
    BaseOffloader,
    NoopOffloader,
    create_offloader,
    get_offloader,
    set_offloader,
)
from vllm.model_executor.offloader.prefetch import PrefetchOffloader
from vllm.model_executor.offloader.uva import UVAOffloader

__all__ = [
    "BaseOffloader",
    "NoopOffloader",
    "UVAOffloader",
    "PrefetchOffloader",
    "create_offloader",
    "get_offloader",
    "set_offloader",
]
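
For orientation, a minimal usage sketch of the package's public API follows. It is editorial commentary, not part of the diff; the call site is hypothetical (the actual wiring into vLLM's model loader is not shown here), and it only exercises the names re-exported above.

# Illustrative sketch (editorial), not part of the committed files.
from vllm.model_executor.offloader import (
    NoopOffloader,
    UVAOffloader,
    get_offloader,
    set_offloader,
)

# The global offloader defaults to a no-op; a loader may swap it out.
assert isinstance(get_offloader(), NoopOffloader)
set_offloader(UVAOffloader(cpu_offload_max_bytes=4 * 1024**3))

# Model-construction code later retrieves the same instance to wrap its layers:
# layers = get_offloader().wrap_modules(layer for layer in decoder_layers)
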

vllm/model_executor/offloader/base.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/offloader.py
"""Base classes for model parameter offloading."""

from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import TYPE_CHECKING

import torch.nn as nn

from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.config import OffloadConfig

logger = init_logger(__name__)


"""
class relation:

BaseOffloader (ABC)
* implemented by: UVAOffloader
* implemented by: PrefetchOffloader
    * uses: _ModuleOffloader
        * uses: _BaseParamOffloader (ABC)
            * implemented by: _CpuParamOffloader
"""


class BaseOffloader(ABC):
    """Base class for model parameter offloading strategies.

    Offloaders control how model parameters are stored and loaded during
    inference. Different strategies trade memory for compute/transfer time.
    """

    @abstractmethod
    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Wrap modules with offloading logic.

        Args:
            modules_generator: Generator yielding modules to potentially offload.

        Returns:
            List of modules, potentially with offloading hooks installed.
        """
        pass

    def post_init(self):
        """Called after model construction completes.

        Offloaders can use this to:
        - Finalize parameter storage
        - Start initial prefetching
        - Allocate shared resources
        """
        return

    def sync_prev_onload(self) -> None:  # noqa: B027
        """Sync previous onload operations. Override in subclasses."""
        pass

    def join_after_forward(self) -> None:  # noqa: B027
        """Join streams after forward. Override in subclasses."""
        pass

    def _wait_for_layer(self, layer_idx: int) -> None:  # noqa: B027
        """Wait for layer prefetch. Override in subclasses."""
        pass

    def _start_prefetch(self, layer_idx: int) -> None:  # noqa: B027
        """Start layer prefetch. Override in subclasses."""
        pass


class NoopOffloader(BaseOffloader):
    """No-op offloader that returns modules as-is without any offloading."""

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Return modules unchanged."""
        return list(modules_generator)


# Global singleton offloader instance (defaults to no-op).
_instance: BaseOffloader = NoopOffloader()


def get_offloader() -> BaseOffloader:
    """Get the global offloader instance."""
    return _instance


def set_offloader(instance: BaseOffloader) -> None:
    """Set the global offloader instance."""
    global _instance
    _instance = instance
    logger.info("Offloader set to %s", type(instance).__name__)


def create_offloader(offload_config: "OffloadConfig") -> BaseOffloader:
    """Create an offloader based on the offload configuration.

    Uses the explicit ``offload_backend`` selector. When set to ``"auto"``,
    selects prefetch if ``offload_group_size > 0``, UVA if
    ``cpu_offload_gb > 0``, otherwise noop.
    """
    from vllm.model_executor.offloader.prefetch import PrefetchOffloader
    from vllm.model_executor.offloader.uva import UVAOffloader

    backend = offload_config.offload_backend
    uva = offload_config.uva
    prefetch = offload_config.prefetch

    if backend == "auto":
        if prefetch.offload_group_size > 0:
            backend = "prefetch"
        elif uva.cpu_offload_gb > 0:
            backend = "uva"
        else:
            return NoopOffloader()

    if backend == "prefetch":
        return PrefetchOffloader(
            group_size=prefetch.offload_group_size,
            num_in_group=prefetch.offload_num_in_group,
            prefetch_step=prefetch.offload_prefetch_step,
            offload_params=prefetch.offload_params,
            mode="cpu",
        )
    elif backend == "uva":
        return UVAOffloader(
            cpu_offload_max_bytes=int(uva.cpu_offload_gb * 1024**3),
            cpu_offload_params=uva.cpu_offload_params,
        )
    else:
        return NoopOffloader()
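
The ``create_offloader`` selection logic above can be exercised without the real ``OffloadConfig`` by passing a structurally compatible stand-in. The sketch below is editorial; the SimpleNamespace stand-in is an assumption that only populates the attributes ``create_offloader`` actually reads.

# Illustrative sketch (editorial), not part of the committed files.
from types import SimpleNamespace

from vllm.model_executor.offloader import NoopOffloader, create_offloader

# Stand-in for OffloadConfig with only the fields create_offloader reads.
cfg = SimpleNamespace(
    offload_backend="auto",
    uva=SimpleNamespace(cpu_offload_gb=0, cpu_offload_params=set()),
    prefetch=SimpleNamespace(
        offload_group_size=0,
        offload_num_in_group=0,
        offload_prefetch_step=0,
        offload_params=set(),
    ),
)

# With every knob at zero, "auto" resolves to the no-op offloader.
assert isinstance(create_offloader(cfg), NoopOffloader)

# offload_group_size > 0 would select PrefetchOffloader instead, and
# cpu_offload_gb > 0 (with group size 0) would select UVAOffloader.
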

vllm/model_executor/offloader/prefetch.py (new file, 704 lines)
@@ -0,0 +1,704 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/offloader.py
"""Prefetch-based CPU offloading with async prefetching.

Uses static buffers and event-based stream forking for torch.compile +
CUDA graph compatibility. Events allow the copy stream to join CUDA
graph captures, ensuring H2D copies are properly captured.
"""

from abc import ABC, abstractmethod
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

# Import prefetch_ops to register custom ops at module load time
import vllm.model_executor.offloader.prefetch_ops  # noqa: F401
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import BaseOffloader
from vllm.utils.platform_utils import is_pin_memory_available

logger = init_logger(__name__)


@dataclass
class ParamInfo:
    """Metadata about an offloaded parameter."""

    name: str
    shape: tuple[int, ...]
    stride: tuple[int, ...]
    dtype: torch.dtype

    @property
    def key(self) -> tuple[str, tuple[int, ...], tuple[int, ...], torch.dtype]:
        """Unique key for buffer pool grouping.

        Includes parameter name to prevent different parameters with the same
        shape from sharing buffers within the same layer. Parameters with the
        same name across different layers will share buffers (via slots).

        Includes stride because parameters with the same shape but different
        strides need separate buffers to preserve memory layout.
        """
        return (self.name, self.shape, self.stride, self.dtype)

    @property
    def num_bytes(self) -> int:
        """Size in bytes."""
        numel = 1
        for dim in self.shape:
            numel *= dim
        return numel * torch.finfo(self.dtype).bits // 8


class StaticBufferPool:
    """Pre-allocated GPU buffer pool for offloaded parameters.

    Allocates slot_capacity copies of each unique parameter
    (name, shape, stride, dtype), allowing for double/triple buffering
    during prefetch.

    Buffer slots are reused circularly: layer N uses slot (N % slot_capacity).

    The key includes parameter name to prevent different parameters within
    the same layer from sharing buffers. Parameters with the same name
    across different layers share buffers via the slot mechanism.
    """

    def __init__(
        self,
        param_infos: list[ParamInfo],
        slot_capacity: int,
        device: torch.device,
    ):
        self.slot_capacity = slot_capacity
        self.total_bytes = 0
        self._device = device

        # Group by (name, shape, stride, dtype) - only allocate unique combinations
        unique_params: dict[tuple, ParamInfo] = {}
        for info in param_infos:
            if info.key not in unique_params:
                unique_params[info.key] = info

        # Allocate buffers: key -> list of tensors (one per slot)
        self._buffers: dict[tuple, list[torch.Tensor]] = {}
        for key, info in unique_params.items():
            slot_tensors = []
            for _ in range(slot_capacity):
                # Use empty_strided to preserve parameter's memory layout
                buf = torch.empty_strided(
                    size=info.shape,
                    stride=info.stride,
                    dtype=info.dtype,
                    device=device,
                )
                slot_tensors.append(buf)
                self.total_bytes += info.num_bytes
            self._buffers[key] = slot_tensors

        logger.debug(
            "[StaticBufferPool] Allocated %d unique (name, shape, stride, dtype), "
            "%d slots each, total %.4f GB",
            len(unique_params),
            slot_capacity,
            self.total_bytes / 1e9,
        )

    def get_buffer(
        self,
        name: str,
        shape: tuple[int, ...],
        stride: tuple[int, ...],
        dtype: torch.dtype,
        slot_idx: int,
    ) -> torch.Tensor:
        """Get a static buffer for the given name/shape/stride/dtype/slot."""
        key = (name, shape, stride, dtype)
        return self._buffers[key][slot_idx % self.slot_capacity]


class PrefetchOffloader(BaseOffloader):
    """Prefetching-based offloader with group-based layer selection.

    Groups layers and uses async H2D prefetch to hide transfer latency.
    Uses static buffers and stream synchronization for torch.compile and
    CUDA graph compatibility.

    Args:
        group_size: Group every N layers together.
        num_in_group: Offload this many layers per group (last N of each group).
        prefetch_step: Number of layers to prefetch ahead.
        offload_params: Optional set of parameter name segments; when
            non-empty, only matching parameters are offloaded.
        mode: Offload mode ("cpu" is currently supported).
    """

    def __init__(
        self,
        group_size: int,
        num_in_group: int,
        prefetch_step: int,
        offload_params: set[str] | None = None,
        mode: str = "cpu",
    ):
        self.group_size = group_size
        self.num_in_group = num_in_group
        self.prefetch_step = prefetch_step
        self.offload_params = offload_params or set()
        self.mode = mode

        # Copy stream for async H2D transfers
        self.copy_stream = torch.cuda.Stream()

        # Module offloaders and buffer pool (populated in wrap_modules/post_init)
        self.module_offloaders: list[_ModuleOffloader] = []
        self.buffer_pool: StaticBufferPool | None = None
        self.total_offloaded_bytes = 0

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Wrap modules with prefetch offloading logic."""
        assert len(self.module_offloaders) == 0, (
            "wrap_modules should only be called once"
        )

        all_modules = []
        offload_modules = []

        for module_index, module in enumerate(modules_generator):
            all_modules.append(module)

            # Select layers to offload based on group pattern:
            # offload last num_in_group layers of each group_size.
            if module_index % self.group_size >= self.group_size - self.num_in_group:
                if self.offload_params:
                    whitelist = [
                        name
                        for name, _ in module.named_parameters()
                        if any(f".{p}." in f".{name}." for p in self.offload_params)
                    ]
                else:
                    whitelist = [name for name, _ in module.named_parameters()]

                if not whitelist:
                    continue  # skip layers with no matching params

                offload_modules.append(module)
                self.module_offloaders.append(
                    _ModuleOffloader(
                        mode=self.mode,
                        module=module,
                        copy_stream=self.copy_stream,
                        whitelist_param_names=whitelist,
                        layer_idx=len(self.module_offloaders),
                    )
                )

        for index, module in enumerate(offload_modules):
            self._hook_module_forward(index, module)

        return all_modules

    def _hook_module_forward(self, index: int, module: nn.Module):
        """Hook module's forward with torch.compile-compatible sync."""
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Temporarily restore original forward to avoid recursion
            module.forward = original_forward

            # Wait for this layer's prefetch to complete.
            # mutates_args on input_tensor creates data dependency for torch.compile.
            input_tensor = args[0] if args else kwargs.get("hidden_states")
            torch.ops.vllm.wait_prefetch(input_tensor, index)

            # No parameter swapping needed - parameters already point to
            # GPU static buffers (set in assign_static_buffer)
            output = original_forward(*args, **kwargs)

            # Start prefetch for next layer (circular).
            # mutates_args on output_tensor creates ordering dependency.
            next_index = (index + self.prefetch_step) % len(self.module_offloaders)
            # Handle tuple output (e.g., (hidden_states, residual))
            if isinstance(output, tuple):
                torch.ops.vllm.start_prefetch(output[0], next_index)
            else:
                torch.ops.vllm.start_prefetch(output, next_index)

            # No explicit offload needed - static buffers are reused implicitly

            # Restore hooked forward
            module.forward = forward
            return output

        module.forward = forward

    def _wait_for_layer(self, layer_idx: int):
        """Called by custom op - wait for copy to complete.

        Synchronization strategy:
        - During CUDA graph capture: use event-based wait (graph-compatible)
        - Outside capture (warmup/eager): use wait_stream (more robust)

        During capture, we skip wait for pre-capture prefetches because:
        1. sync_before_graph_capture() ensures pre-capture work is complete
        2. We can't wait on pre-capture events during capture (isolation error)
        """
        offloader = self.module_offloaders[layer_idx]

        if torch.cuda.is_current_stream_capturing():
            # During capture, skip wait for pre-capture prefetches.
            # sync_before_graph_capture() ensures pre-capture work is complete.
            if not offloader._prefetch_in_capture:
                return
            # Event-based wait for in-capture prefetches (graph-compatible)
            torch.cuda.current_stream().wait_event(offloader._copy_done_event)
            # Mark that this prefetch has been waited on (joined).
            offloader._prefetch_in_capture = False
        else:
            if offloader._event_valid_for_eager:
                # Use per-layer event to only wait for THIS layer's copy,
                # allowing other layers' prefetches to run concurrently.
                torch.cuda.current_stream().wait_event(offloader._copy_done_event)
            else:
                # Event not usable (unrecorded or recorded during capture).
                # Fall back to wait_stream to drain all copy_stream work.
                torch.cuda.current_stream().wait_stream(self.copy_stream)

    def sync_prev_onload(self):
        """Sync previous onload operations.

        Ensures any H2D copies in flight on copy_stream complete before
        the compute stream continues. Call this before CUDA graph
        capture/replay or when synchronization is needed.
        """
        torch.cuda.current_stream().wait_stream(self.copy_stream)

    def _start_prefetch(self, layer_idx: int):
        """Called by custom op - start async copy to static buffer."""
        offloader = self.module_offloaders[layer_idx]
        offloader.start_onload_to_static()

    def join_after_forward(self):
        """Join copy_stream after model forward completes.

        Call this after the model forward pass but before CUDA graph capture
        ends. This ensures copy_stream is rejoined for any prefetches started
        during the forward pass.

        We join ALL layers that have _prefetch_in_capture=True, meaning their
        prefetch was started during capture but not yet waited on (joined).
        This handles both full and piecewise cudagraph modes correctly:
        - Full mode: joins layers 0..prefetch_step-1 (prefetched by last layers)
        - Piecewise mode: joins only layers prefetched by THIS subgraph's layers
        """
        if not self.module_offloaders:
            return
        # Join all layers whose prefetch was started in capture but not waited on
        for offloader in self.module_offloaders:
            if offloader._prefetch_in_capture:
                torch.cuda.current_stream().wait_event(offloader._copy_done_event)
                offloader._prefetch_in_capture = False

    def post_init(self):
        """Allocate static buffer pool and start initial prefetches.

        Note: Parameters have already been offloaded to CPU during wrap_modules()
        (in _CpuParamOffloader.__init__), so GPU memory is available for the
        static buffer pool.
        """
        # Sync CPU storage with current param.data BEFORE collecting param info.
        # This is needed because process_weights_after_loading may have:
        # 1. Transformed weights (quantization, transpose, etc.)
        # 2. Created new CPU tensors via device_loading_context
        # Our _cpu_storage would be stale otherwise.
        for offloader in self.module_offloaders:
            offloader.sync_cpu_storage()

        # Collect parameter info (now using synced CPU storage)
        param_infos: list[ParamInfo] = []
        device: torch.device | None = None

        for offloader in self.module_offloaders:
            param_infos.extend(offloader.get_param_infos())
            if device is None:
                device = offloader.device

        if device is None:
            # No modules to offload
            return

        # Allocate static buffer pool
        self.buffer_pool = StaticBufferPool(
            param_infos=param_infos,
            slot_capacity=self.prefetch_step,
            device=device,
        )

        # Assign buffer slots and point parameters to GPU buffers
        for idx, offloader in enumerate(self.module_offloaders):
            slot_idx = idx % self.prefetch_step
            offloader.assign_buffer_slot(self.buffer_pool, slot_idx)

        # Collect offloaded bytes
        for offloader in self.module_offloaders:
            offloader.post_init()
            self.total_offloaded_bytes += offloader.offloaded_bytes

        logger.info_once(
            f"[PrefetchOffloader] Initialized {len(self.module_offloaders)} modules. "
            f"Total GPU memory saved: {self.total_offloaded_bytes / 1e9:.4f} GB, "
            f"Static buffer pool: {self.buffer_pool.total_bytes / 1e9:.4f} GB "
            f"(group_size={self.group_size}, num_in_group={self.num_in_group}, "
            f"prefetch_step={self.prefetch_step}, mode={self.mode})"
        )

        # Start initial prefetches
        for i in range(min(self.prefetch_step, len(self.module_offloaders))):
            self.module_offloaders[i].start_onload_to_static()


class _ModuleOffloader:
    """Manages offloading for a single module.

    Uses static buffers from a shared pool instead of dynamic allocation.
    """

    def __init__(
        self,
        mode: str,
        module: nn.Module,
        copy_stream: torch.cuda.Stream,
        whitelist_param_names: list[str],
        layer_idx: int,
    ):
        self.mode = mode
        self.module = module
        self.device = next(module.parameters()).device
        self.copy_stream = copy_stream
        self.layer_idx = layer_idx
        self.offloaded_bytes = 0

        # Event to signal when H2D copy to static buffer is complete.
        # Used for per-layer synchronization (both eager and capture modes).
        self._copy_done_event = torch.cuda.Event()

        # Track whether _copy_done_event is valid for eager-mode wait_event.
        # False when: (1) never recorded, or (2) last recorded during a
        # cudagraph capture (events become invalid after capture ends).
        # In these cases we fall back to wait_stream.
        self._event_valid_for_eager = False

        # Track if last prefetch was started during CUDA graph capture.
        # Used to skip wait_event during capture for pre-capture prefetches.
        self._prefetch_in_capture = False

        assert self.device != torch.device("cpu"), (
            "Module parameters should not already be on CPU "
            "(offloader handles CPU placement)"
        )

        # Buffer pool and slot (assigned in assign_buffer_slot)
        self._buffer_pool: StaticBufferPool | None = None
        self._buffer_slot_idx: int = 0

        param_dict = dict(self.module.named_parameters())
        assert all(name in param_dict for name in whitelist_param_names), (
            f"Whitelist params {whitelist_param_names} not found in module params "
            f"{list(param_dict.keys())}"
        )

        self._param_offloaders = {
            name: _BaseParamOffloader.create(mode, module=module, param_name=name)
            for name in whitelist_param_names
        }

    def post_init(self):
        """Collect total offloaded bytes (offloading already done in __init__)."""
        for param_offloader in self._param_offloaders.values():
            param_offloader.post_init()
            self.offloaded_bytes += param_offloader.offloaded_bytes

    def sync_cpu_storage(self):
        """Sync CPU storage with current param.data.

        Called after process_weights_after_loading to ensure _cpu_storage
        contains the final processed weights, not stale pre-loading data.
        """
        for param_offloader in self._param_offloaders.values():
            param_offloader.sync_cpu_storage()

    def get_param_infos(self) -> list[ParamInfo]:
        """Get parameter metadata for buffer pool allocation.

        Note: sync_cpu_storage() must be called before this method to ensure
        _cpu_storage reflects the final processed weights (after quantization).
        """
        infos = []
        for name, offloader in self._param_offloaders.items():
            cpu_storage = offloader._cpu_storage
            assert cpu_storage is not None, "CPU storage not initialized"
            infos.append(
                ParamInfo(
                    name=name,
                    shape=tuple(cpu_storage.shape),
                    stride=tuple(cpu_storage.stride()),
                    dtype=cpu_storage.dtype,
                )
            )
        return infos

    def assign_buffer_slot(self, pool: StaticBufferPool, slot_idx: int):
        """Assign this module to a buffer slot in the pool.

        Also assigns static GPU buffers to each parameter offloader,
        which moves the parameter data to point to the GPU buffer.
        """
        self._buffer_pool = pool
        self._buffer_slot_idx = slot_idx

        # Assign static buffers to parameters.
        # Use CPU storage shape/stride/dtype since param.data is now empty.
        for name, offloader in self._param_offloaders.items():
            cpu_storage = offloader._cpu_storage
            assert cpu_storage is not None, "CPU storage not initialized"
            buffer = pool.get_buffer(
                name=name,
                shape=tuple(cpu_storage.shape),
                stride=tuple(cpu_storage.stride()),
                dtype=cpu_storage.dtype,
                slot_idx=slot_idx,
            )
            offloader.assign_static_buffer(buffer)

    def start_onload_to_static(self):
        """Start async copy from CPU storage to GPU buffer.

        Uses event-based forking to join copy_stream to CUDA graph capture.
        This ensures H2D copies are properly captured when recording a graph.

        IMPORTANT: We must wait for the compute stream before copying, because
        the previous layer's forward may still be using the buffer (GPU ops are
        async). Without this sync, we could overwrite the buffer while it's
        being read.
        """
        assert self._buffer_pool is not None, "Buffer pool not assigned"

        # Track if this prefetch is being captured (for _wait_for_layer logic)
        self._prefetch_in_capture = torch.cuda.is_current_stream_capturing()

        # Fork: record event on compute stream, copy_stream waits on it.
        # This joins copy_stream to any active CUDA graph capture.
        fork_event = torch.cuda.Event()
        torch.cuda.current_stream().record_event(fork_event)
        self.copy_stream.wait_event(fork_event)

        with torch.cuda.stream(self.copy_stream):
            for name, offloader in self._param_offloaders.items():
                cpu_storage = offloader._cpu_storage
                gpu_buffer = offloader._gpu_buffer
                assert cpu_storage is not None, "CPU storage not initialized"
                assert gpu_buffer is not None, "GPU buffer not assigned"
                assert not is_pin_memory_available() or cpu_storage.is_pinned(), (
                    f"CPU storage for {name} is not pinned! "
                    "non_blocking=True H2D copy from non-pinned memory "
                    "causes stream synchronization that breaks "
                    "event-based fork synchronization."
                )
                gpu_buffer.copy_(cpu_storage, non_blocking=True)

            # Record completion event for _wait_for_layer to use
            self._copy_done_event.record(self.copy_stream)

        # Event is only valid for eager wait_event if recorded outside capture.
        # Events recorded during capture become invalid after capture ends.
        self._event_valid_for_eager = not torch.cuda.is_current_stream_capturing()


class _BaseParamOffloader(ABC):
    """Base class for parameter offloading strategies."""

    # CPU storage for offloaded parameters (set by subclasses)
    _cpu_storage: torch.Tensor | None
    # GPU buffer reference (set by subclasses when using static buffers)
    _gpu_buffer: torch.Tensor | None

    @staticmethod
    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
        """Factory method to create appropriate offloader for mode."""
        if mode == "cpu":
            return _CpuParamOffloader(**kwargs)
        else:
            raise ValueError(f"Unknown offload mode: {mode}")

    def __init__(self, module: nn.Module, param_name: str):
        self._module = module
        self._param_name = param_name
        self.offloaded_bytes = 0
        self._cpu_storage = None
        self._gpu_buffer = None

    @property
    def _param(self) -> nn.Parameter:
        """Get the parameter being offloaded.

        Supports dotted names (e.g. 'self_attn.qkv_proj.weight') by
        traversing the module hierarchy.
        """
        obj: Any = self._module
        for attr in self._param_name.split("."):
            obj = getattr(obj, attr)
        return obj

    def post_init(self):
        """Initialize offloading (move parameter to storage)."""
        return

    @abstractmethod
    def sync_cpu_storage(self) -> None:
        """Sync CPU storage with current param.data.

        Called after process_weights_after_loading to update _cpu_storage
        with the final processed weights.
        """
        pass

    @abstractmethod
    def assign_static_buffer(self, gpu_buffer: torch.Tensor) -> None:
        """Point parameter data to GPU static buffer."""
        pass


class _CpuParamOffloader(_BaseParamOffloader):
    """Offload parameter to pinned CPU memory.

    Uses GPU static buffers as the actual parameter, with CPU storage
    kept separately. This ensures torch.compile sees GPU tensors at trace time.

    The offloading happens in two phases:
    1. __init__() - copies GPU data to CPU, frees GPU memory immediately
    2. assign_static_buffer() - points param.data to GPU static buffer
    """

    def __init__(self, module: nn.Module, param_name: str):
        super().__init__(module, param_name)
        self._cpu_storage: torch.Tensor | None = None
        self._gpu_buffer: torch.Tensor | None = None  # Store reference to GPU buffer

        # Offload to CPU immediately to free GPU memory during model loading
        self._offload_to_cpu_internal()

    def _offload_to_cpu_internal(self):
        """Copy parameter data to pinned CPU storage and free GPU memory.

        This replaces param.data with CPU storage, allowing weight loading
        to continue writing to CPU memory. GPU memory is freed when the
        original GPU tensor is garbage collected.
        """
        param = self._param
        pin_memory = is_pin_memory_available()

        # Create pinned CPU storage and copy current GPU data
        self._cpu_storage = torch.empty_strided(
            size=param.data.size(),
            stride=param.data.stride(),
            dtype=param.data.dtype,
            layout=param.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        self._cpu_storage.copy_(param.data)

        self.offloaded_bytes = (
            self._cpu_storage.numel() * self._cpu_storage.element_size()
        )

        # Point param.data to CPU storage - this allows weight loading to work
        # and frees GPU memory when the original GPU tensor is garbage collected
        param.data = self._cpu_storage

    def _update_cpu_storage_from_param(self) -> None:
        """Update _cpu_storage from current param.data, ensuring pinned memory.

        After process_weights_after_loading, device_loading_context creates
        non-pinned CPU tensors via `p.data = p.data.to("cpu")`. Using
        non-pinned memory with `copy_(src, non_blocking=True)` causes CUDA to
        perform a stream synchronization before the copy, breaking the
        event-based fork synchronization and potentially allowing the copy
        to overwrite the GPU buffer while the compute stream still reads it.

        This method ensures _cpu_storage always uses pinned memory when
        available, re-pinning if necessary.
        """
        param = self._param

        if param.data.device.type == "cpu":
            if is_pin_memory_available() and not param.data.is_pinned():
                pinned = torch.empty_strided(
                    size=param.data.size(),
                    stride=param.data.stride(),
                    dtype=param.data.dtype,
                    layout=param.data.layout,
                    device="cpu",
                    pin_memory=True,
                )
                pinned.copy_(param.data)
                self._cpu_storage = pinned
            else:
                self._cpu_storage = param.data
        else:
            # param.data is on GPU - copy to existing CPU storage
            assert self._cpu_storage is not None
            self._cpu_storage.copy_(param.data)

    def assign_static_buffer(self, gpu_buffer: torch.Tensor) -> None:
        """Point parameter data to GPU static buffer.

        This is called after weight loading AND process_weights_after_loading
        complete. At this point:
        - param.data may have been replaced by device_loading_context
          (which creates new CPU tensors after quantization processing)
        - We need to update _cpu_storage to point to current param.data
          so that prefetch copies the processed weights, not stale data
        - Then point param.data to the GPU buffer for torch.compile
        """
        assert self._cpu_storage is not None, (
            "_offload_to_cpu_internal() must be called before assign_static_buffer()"
        )

        # Get current parameter (may have been replaced by
        # process_weights_after_loading)
        param = self._param

        # Update _cpu_storage to current param.data. This is critical because:
        # 1. process_weights_after_loading may transform weights (quantization)
        # 2. device_loading_context creates NEW CPU tensors when moving back
        # 3. Our old _cpu_storage would have pre-processed or stale data
        self._update_cpu_storage_from_param()

        # Store reference to GPU buffer for use in start_onload
        self._gpu_buffer = gpu_buffer

        # Point parameter to static GPU buffer - this is what torch.compile sees
        param.data = gpu_buffer

    def sync_cpu_storage(self) -> None:
        """Sync CPU storage with current param.data.

        Called after process_weights_after_loading to update _cpu_storage
        with the final processed weights. This is critical because:
        1. process_weights_after_loading may transform weights (quantization)
        2. device_loading_context creates NEW CPU tensors when moving back
        3. Our old _cpu_storage would have pre-processed or stale data
        """
        self._update_cpu_storage_from_param()

    def post_init(self):
        """No-op: offloading done in offload_to_cpu/assign_static_buffer."""
        pass
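
To make the group and slot arithmetic in PrefetchOffloader concrete, the following standalone sketch (editorial; plain Python, no vLLM or CUDA required) reproduces the two rules from wrap_modules and post_init: a layer is offloaded when ``layer_idx % group_size >= group_size - num_in_group``, and the i-th offloaded module is assigned static-buffer slot ``i % prefetch_step``. The parameter values are made up for illustration.

# Illustrative sketch (editorial), not part of the committed files.
group_size, num_in_group, prefetch_step = 4, 2, 2
num_layers = 12

offloaded = [
    i for i in range(num_layers)
    if i % group_size >= group_size - num_in_group
]
print(offloaded)  # last 2 layers of every group of 4: [2, 3, 6, 7, 10, 11]

# Each offloaded module gets a slot in the StaticBufferPool; with
# prefetch_step=2 the slots alternate, so a layer's buffer is not reused
# until prefetch_step offloaded layers later.
slots = {layer: idx % prefetch_step for idx, layer in enumerate(offloaded)}
print(slots)  # {2: 0, 3: 1, 6: 0, 7: 1, 10: 0, 11: 1}
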

vllm/model_executor/offloader/prefetch_ops.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom ops for prefetch offloader torch.compile + CUDA graph compatibility.

These ops use mutates_args to create data dependencies that prevent
the compiler from reordering prefetch/sync operations.
"""

from __future__ import annotations

import torch

from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.torch_utils import direct_register_custom_op

# --- wait_prefetch op ---


def _wait_prefetch_impl(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Wait for prefetch of layer_idx to complete.

    Synchronizes the compute stream with the copy stream to ensure
    the prefetched weights are ready for use.

    Args:
        input_tensor: Input to the layer (e.g., hidden_states) - declared
            as mutated to create data dependency for torch.compile.
        layer_idx: Index of the layer to wait for.
    """
    get_offloader()._wait_for_layer(layer_idx)


def _wait_prefetch_fake(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return


# --- start_prefetch op ---


def _start_prefetch_impl(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Start async prefetch of layer_idx weights.

    Initiates H2D copy on the copy stream for the specified layer.

    Args:
        output_tensor: Output from forward - declared as mutated to
            prevent torch.compile from reordering this op before the
            computation that produces output_tensor.
        layer_idx: Index of the layer to prefetch.
    """
    get_offloader()._start_prefetch(layer_idx)


def _start_prefetch_fake(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return


def register_prefetch_offloader_ops() -> None:
    """Register custom ops for prefetch offloader.

    Must be called before the ops are used. This is typically done
    at module import time.
    """
    direct_register_custom_op(
        op_name="wait_prefetch",
        op_func=_wait_prefetch_impl,
        mutates_args=["input_tensor"],
        fake_impl=_wait_prefetch_fake,
    )

    direct_register_custom_op(
        op_name="start_prefetch",
        op_func=_start_prefetch_impl,
        mutates_args=["output_tensor"],
        fake_impl=_start_prefetch_fake,
    )


# Register ops at module import time
register_prefetch_offloader_ops()
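
Once this module has been imported, the two ops are reachable under ``torch.ops.vllm`` and dispatch to the global offloader, exactly as the hooked forward in prefetch.py calls them. The short sketch below is editorial; the tensor and layer indices are placeholders, and with the default NoopOffloader installed both calls fall through to the no-op base-class methods.

# Illustrative sketch (editorial), not part of the committed files.
import torch

import vllm.model_executor.offloader.prefetch_ops  # noqa: F401  (registers the ops)

hidden_states = torch.zeros(1, 8)  # stand-in for a layer input/output
torch.ops.vllm.wait_prefetch(hidden_states, 0)   # -> get_offloader()._wait_for_layer(0)
torch.ops.vllm.start_prefetch(hidden_states, 1)  # -> get_offloader()._start_prefetch(1)
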

vllm/model_executor/offloader/uva.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""UVA-based CPU offloading using Unified Virtual Addressing."""

from collections.abc import Generator

import torch
import torch.nn as nn
from torch.func import functional_call

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import BaseOffloader
from vllm.utils.mem_utils import format_gib
from vllm.utils.platform_utils import is_pin_memory_available, is_uva_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

logger = init_logger(__name__)


class UVAOffloader(BaseOffloader):
    """Offloader using Unified Virtual Addressing (UVA) for zero-copy access.

    This offloader moves parameters to pinned CPU memory and creates CUDA views
    using UVA. The GPU can then directly access the CPU memory without explicit
    transfers, at the cost of PCIe bandwidth (slower than GPU memory).

    When UVA is disabled via env var, falls back to a functional_call-based
    approach that moves parameters on-demand.

    Args:
        cpu_offload_max_bytes: Maximum bytes to offload to CPU.
        cpu_offload_params: Set of parameter name segments to selectively
            offload. If empty, all parameters are eligible up to the byte limit.
    """

    def __init__(
        self,
        cpu_offload_max_bytes: int,
        cpu_offload_params: set[str] | None = None,
    ):
        self.cpu_offload_max_bytes = cpu_offload_max_bytes
        self.cpu_offload_bytes = 0
        self.cpu_offload_params = cpu_offload_params or set()

        self.pin_memory = (
            is_pin_memory_available()
            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
        )
        self.uva_offloading = (
            is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA
        )

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Wrap modules with UVA offloading."""
        modules = [self._maybe_offload_to_cpu(module) for module in modules_generator]
        if self.cpu_offload_bytes > 0:
            logger.info(
                "Total CPU offloaded parameters: %s",
                format_gib(self.cpu_offload_bytes),
            )
        return modules

    def _maybe_offload_to_cpu(self, module: nn.Module) -> nn.Module:
        """Offload module parameters to CPU using UVA if budget allows."""
        if (params := next(module.parameters(), None)) is None:
            return module

        device = params.device

        if device == torch.device("cpu"):
            return module

        if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
            return module

        # offload parameters to CPU
        # use pin_memory if possible, which helps cudagraph capture speed
        offloaded_parameters = False
        for name, p in module.named_parameters():
            if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
                # we use per-parameter offloading
                # one module might have some parameters offloaded and some not
                break

            if self.cpu_offload_params:
                # Check if parameter belongs to the offloading set
                # Add dots here to ensure we match full segments only
                # e.g., "experts.w2_weight" matches "mlp.experts.w2_weight"
                # but not "mlp.experts.w2_weight_scale"
                should_offload = any(
                    f".{param}." in f".{name}." for param in self.cpu_offload_params
                )
                if not should_offload:
                    continue

            cpu_data = p.data.to(device="cpu")
            if self.pin_memory:
                cpu_data = cpu_data.pin_memory()

            if not self.uva_offloading:
                p.data = cpu_data
            else:
                p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
                p._vllm_is_uva_offloaded = True

            self.cpu_offload_bytes += p.data.numel() * p.data.element_size()
            offloaded_parameters = True

        if offloaded_parameters and not self.uva_offloading:
            original_forward = module.forward

            def forward(*args, **kwargs):
                module.forward = original_forward
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device,
                    # it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }

                # set `tie_weights=False` as tied weights in original model
                # become untied when calling .to(device) individually
                output = functional_call(
                    module,
                    device_state,
                    args=args,
                    kwargs=kwargs,
                    tie_weights=False,
                )
                module.forward = forward
                return output

            module.forward = forward

        return module
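
The dotted segment-matching rule shared by UVAOffloader (``cpu_offload_params``) and PrefetchOffloader (``offload_params``) can be checked in isolation. The sketch below is editorial and simply restates the membership test from _maybe_offload_to_cpu, using the example names already given in the code comments.

# Illustrative sketch (editorial), not part of the committed files.
def matches(segment: str, param_name: str) -> bool:
    # Wrap both sides in dots so only whole dotted segments match.
    return f".{segment}." in f".{param_name}."

assert matches("experts.w2_weight", "mlp.experts.w2_weight")
assert not matches("experts.w2_weight", "mlp.experts.w2_weight_scale")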