Refactor weight offloading logic (#8521)
@@ -96,6 +96,11 @@ from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
     is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -222,9 +226,6 @@ class ModelRunner:
             }
         )
 
-        # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
            self.init_threads_binding()
@@ -232,6 +233,9 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # CPU offload
+        set_offloader(create_offloader_from_server_args(server_args))
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -690,6 +694,8 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
+        get_offloader().post_init()
+
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
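The new hooks are intentionally small: an offloader is installed once during ModelRunner initialization and finalized after weights are loaded. A minimal sketch of plugging in a custom implementation through these hooks, assuming only the BaseOffloader interface introduced by this PR (the LoggingOffloader class below is hypothetical, not part of the change):

from sglang.srt.offloader import BaseOffloader, get_offloader, set_offloader


class LoggingOffloader(BaseOffloader):
    """Hypothetical offloader that only logs the modules it is asked to wrap."""

    def wrap_modules(self, all_modules_generator, submodule_accessor=None,
                     whitelist_param_names_creator=None):
        modules = list(all_modules_generator)
        print(f"wrapping {len(modules)} modules")
        return modules

    def post_init(self):
        print("all weights loaded")


set_offloader(LoggingOffloader())   # done once during ModelRunner __init__
# ... model construction calls get_offloader().wrap_modules(...) internally ...
get_offloader().post_init()         # called after load_model() finishes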
python/sglang/srt/offloader.py  (new file, 122 lines)
@@ -0,0 +1,122 @@
+import logging
+from abc import ABC
+from typing import Callable, Generator, List, Optional
+
+import torch
+from torch.func import functional_call
+
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_pin_memory_available
+
+logger = logging.getLogger(__name__)
+
+_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
+_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]
+
+
+class BaseOffloader(ABC):
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        return list(all_modules_generator)
+
+    def post_init(self):
+        pass
+
+
+class NoopOffloader(BaseOffloader):
+    pass
+
+
+# For simplicity use singleton, but can surely support multi instance
+_instance: Optional[BaseOffloader] = NoopOffloader()
+
+
+def get_offloader():
+    assert _instance is not None
+    return _instance
+
+
+def set_offloader(instance: BaseOffloader):
+    global _instance
+    _instance = instance
+
+
+def create_offloader_from_server_args(server_args: ServerArgs):
+    if server_args.cpu_offload_gb > 0:
+        return OffloaderV1(
+            cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
+        )
+    return NoopOffloader()
+
+
+class OffloaderV1(BaseOffloader):
+    def __init__(self, cpu_offload_max_bytes: int):
+        self._cpu_offload_bytes = 0
+        self._cpu_offload_max_bytes = cpu_offload_max_bytes
+
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]
+
+    def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
+        if (params := next(module.parameters(), None)) is None:
+            return module
+
+        device = params.device
+
+        if device == torch.device("cpu"):
+            return module
+
+        if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
+            return module
+
+        pin_memory = is_pin_memory_available()
+        # offload parameters to CPU
+        # use pin_memory if possible, which helps cudagraph capture speed
+        offloaded_parameters = False
+        for p in module.parameters():
+            if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
+                # we use per-parameter offloading
+                # one module might have some parameters offloaded and some not
+                break
+
+            # `torch.empty_like` does not support `pin_memory` argument
+            cpu_data = torch.empty_strided(
+                size=p.data.size(),
+                stride=p.data.stride(),
+                dtype=p.data.dtype,
+                layout=p.data.layout,
+                device="cpu",
+                pin_memory=pin_memory,
+            )
+            cpu_data.copy_(p.data)
+            p.data = cpu_data
+            self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
+            offloaded_parameters = True
+
+        if offloaded_parameters:
+            original_forward = module.forward
+
+            def forward(*args, **kwargs):
+                module.forward = original_forward
+                device_state = {
+                    # here we blindly call `to(device)`
+                    # if the parameter is already on the device, it will be a no-op
+                    k: v.to(device, non_blocking=True)
+                    for k, v in module.state_dict().items()
+                }
+                output = functional_call(module, device_state, args=args, kwargs=kwargs)
+                module.forward = forward
+                return output
+
+            module.forward = forward
+
+        return module
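Taken together, the new module is used in three phases that mirror the ModelRunner changes above. A rough usage sketch, assuming server_args is a populated ServerArgs with cpu_offload_gb > 0, and build_layer / num_layers are hypothetical stand-ins for a model's layer factory:

from sglang.srt.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)

# Phase 1: pick an implementation from the server configuration (OffloaderV1
# when cpu_offload_gb > 0, otherwise NoopOffloader) and install the singleton.
set_offloader(create_offloader_from_server_args(server_args))

# Phase 2: wrap layers while the model is being built; OffloaderV1 moves
# parameters to (optionally pinned) CPU memory until the byte budget is hit.
layers = get_offloader().wrap_modules(
    build_layer(i) for i in range(num_layers)  # hypothetical layer factory
)

# Phase 3: finalize after weight loading; a no-op for both offloaders in this file.
get_offloader().post_init()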
@@ -438,72 +438,6 @@ def is_pin_memory_available() -> bool:
     return torch.cuda.is_available()
 
 
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-
-
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-
-
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    if (params := next(module.parameters(), None)) is None:
-        return module
-
-    device = params.device
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module, device_state, args=args, kwargs=kwargs)
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
 class LayerFn(Protocol):
 
     def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
@@ -516,11 +450,13 @@ def make_layers(
     pp_size: Optional[int] = None,
     prefix: str = "",
     return_tuple: bool = False,
+    offloader_kwargs: Dict[str, Any] = {},
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
     from sglang.srt.layers.utils import PPMissingLayer
+    from sglang.srt.offloader import get_offloader
 
     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = (
@@ -534,10 +470,13 @@ def make_layers(
     )
     modules = torch.nn.ModuleList(
         [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
-        + [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
-            for idx in range(start_layer, end_layer)
-        ]
+        + get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(start_layer, end_layer)
+            ),
+            **offloader_kwargs,
+        )
         + [
             PPMissingLayer(return_tuple=return_tuple)
             for _ in range(end_layer, num_hidden_layers)
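With the default configuration (cpu_offload_gb == 0) the installed singleton is NoopOffloader, whose inherited wrap_modules simply materializes the generator, so make_layers keeps its previous behavior when offloading is disabled. A small self-contained check, assuming sglang and torch are importable:

import torch

from sglang.srt.offloader import NoopOffloader, get_offloader, set_offloader

set_offloader(NoopOffloader())

# Falls back to BaseOffloader.wrap_modules, which just turns the generator
# into a plain list without touching parameter placement.
layers = get_offloader().wrap_modules(
    torch.nn.Linear(8, 8) for _ in range(4)
)
assert isinstance(layers, list) and len(layers) == 4
assert all(p.device.type == "cpu" for m in layers for p in m.parameters())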