diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a20571253..c43c502da 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -96,6 +96,11 @@ from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.offloader import ( + create_offloader_from_server_args, + get_offloader, + set_offloader, +) from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs @@ -118,7 +123,6 @@ from sglang.srt.utils import ( is_npu, monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, - set_cpu_offload_max_bytes, set_cuda_arch, ) from sglang.srt.weight_sync.tensor_bucket import ( @@ -222,9 +226,6 @@ class ModelRunner: } ) - # CPU offload - set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) - # Init OpenMP threads binding for CPU if self.device == "cpu": self.init_threads_binding() @@ -232,6 +233,9 @@ class ModelRunner: # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + # CPU offload + set_offloader(create_offloader_from_server_args(server_args)) + # Update deep gemm configure if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM: deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args) @@ -690,6 +694,8 @@ class ModelRunner: monkey_patch_vllm_parallel_state(reverse=True) monkey_patch_isinstance_for_vllm_base_layer(reverse=True) + get_offloader().post_init() + if self.server_args.kv_cache_dtype == "fp8_e4m3": if self.server_args.quantization_param_path is not None: if callable(getattr(self.model, "load_kv_cache_scales", None)): diff --git 
"""Pluggable parameter offloading for model layers.

An offloader is handed every layer as it is constructed (see `wrap_modules`)
and may relocate the layer's parameters. One process-wide instance is
registered with `set_offloader` and looked up with `get_offloader`; until one
is registered, a `NoopOffloader` is used.
"""

import logging
from abc import ABC
from typing import TYPE_CHECKING, Callable, Generator, List, Optional

import torch
from torch.func import functional_call

if TYPE_CHECKING:
    # Only referenced in an annotation, so no runtime import is required.
    from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)

_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]


class BaseOffloader(ABC):
    """Common interface of all offloaders; every method is a pass-through."""

    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ) -> List[torch.nn.Module]:
        """Consume the layer generator and return the layers as a list.

        Args:
            all_modules_generator: Yields the layers one at a time.
            submodule_accessor: Not used by this implementation.
            whitelist_param_names_creator: Not used by this implementation.

        Returns:
            The generated layers, unchanged, in generation order.
        """
        return list(all_modules_generator)

    def post_init(self) -> None:
        """Hook reached through `get_offloader().post_init()`; does nothing here."""


class NoopOffloader(BaseOffloader):
    """Offloader that leaves every module where it is."""


# For simplicity use singleton, but can surely support multi instance
_instance: Optional[BaseOffloader] = NoopOffloader()


def get_offloader() -> BaseOffloader:
    """Return the currently registered offloader."""
    assert _instance is not None
    return _instance


def set_offloader(instance: BaseOffloader) -> None:
    """Register `instance` as the process-wide offloader."""
    global _instance
    _instance = instance


def create_offloader_from_server_args(server_args: "ServerArgs") -> BaseOffloader:
    """Build the offloader selected by `server_args.cpu_offload_gb`.

    A positive value yields an `OffloaderV1` whose byte budget is that many
    GiB; anything else yields a `NoopOffloader`.
    """
    if server_args.cpu_offload_gb > 0:
        return OffloaderV1(
            cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
        )
    return NoopOffloader()


class OffloaderV1(BaseOffloader):
    """Keeps parameters in CPU memory until a byte budget is exhausted.

    Offloaded modules get a replacement `forward` that copies their state to
    the original device on every call.
    """

    def __init__(self, cpu_offload_max_bytes: int):
        # Bytes moved to CPU so far / upper bound on that total.
        self._cpu_offload_bytes = 0
        self._cpu_offload_max_bytes = cpu_offload_max_bytes

    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ) -> List[torch.nn.Module]:
        """Run `maybe_offload_to_cpu` on every generated layer.

        `submodule_accessor` and `whitelist_param_names_creator` are accepted
        for interface compatibility and ignored.
        """
        return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]

    def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
        """Move as many of `module`'s parameters to CPU as the budget allows.

        The module is returned untouched when it has no parameters, when its
        first parameter already lives on the CPU, or when the budget is spent.
        Otherwise the same module object is returned with (some of) its
        parameters on the CPU and its `forward` replaced.
        """
        first_param = next(module.parameters(), None)
        if first_param is None:
            return module

        device = first_param.device
        if device == torch.device("cpu"):
            return module

        if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
            return module

        # Deferred import: `sglang.srt.utils.make_layers` imports this module
        # lazily to dodge a cycle, so the reverse edge is kept lazy as well.
        from sglang.srt.utils import is_pin_memory_available

        pin_memory = is_pin_memory_available()
        # offload parameters to CPU
        # use pin_memory if possible, which helps cudagraph capture speed
        offloaded_parameters = False
        for p in module.parameters():
            if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
                # we use per-parameter offloading
                # one module might have some parameters offloaded and some not
                break

            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data
            self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
            offloaded_parameters = True

        if offloaded_parameters:
            self._install_device_forward(module, device)

        return module

    @staticmethod
    def _install_device_forward(
        module: torch.nn.Module, device: torch.device
    ) -> None:
        """Replace `module.forward` with one that runs on copies of the state on `device`."""
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Put the original back for the duration of the call so that
            # `functional_call` does not re-enter this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(
                    module, device_state, args=args, kwargs=kwargs
                )
            finally:
                # Re-install the wrapper even when the forward pass raises;
                # otherwise later calls would run on the CPU-resident weights.
                module.forward = forward

        module.forward = forward
_CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES - _CPU_OFFLOAD_BYTES = 0 - _CPU_OFFLOAD_MAX_BYTES = max_bytes - - -def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: - if (params := next(module.parameters(), None)) is None: - return module - - device = params.device - if device == torch.device("cpu"): - return module - - global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES - if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: - return module - - pin_memory = is_pin_memory_available() - # offload parameters to CPU - # use pin_memory if possible, which helps cudagraph capture speed - offloaded_parameters = False - for p in module.parameters(): - if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: - # we use per-parameter offloading - # one module might have some parameters offloaded and some not - break - - # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided( - size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device="cpu", - pin_memory=pin_memory, - ) - cpu_data.copy_(p.data) - p.data = cpu_data - _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() - offloaded_parameters = True - - if offloaded_parameters: - original_forward = module.forward - - def forward(*args, **kwargs): - module.forward = original_forward - device_state = { - # here we blindly call `to(device)` - # if the parameter is already on the device, it will be a no-op - k: v.to(device, non_blocking=True) - for k, v in module.state_dict().items() - } - output = functional_call(module, device_state, args=args, kwargs=kwargs) - module.forward = forward - return output - - module.forward = forward - - return module - - class LayerFn(Protocol): def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ... 
@@ -516,11 +450,13 @@ def make_layers( pp_size: Optional[int] = None, prefix: str = "", return_tuple: bool = False, + offloader_kwargs: Dict[str, Any] = {}, ) -> Tuple[int, int, torch.nn.ModuleList]: """Make a list of layers with the given layer function""" # circula imports from sglang.srt.distributed import get_pp_indices from sglang.srt.layers.utils import PPMissingLayer + from sglang.srt.offloader import get_offloader assert not pp_size or num_hidden_layers >= pp_size start_layer, end_layer = ( @@ -534,10 +470,13 @@ def make_layers( ) modules = torch.nn.ModuleList( [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)] - + [ - maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix))) - for idx in range(start_layer, end_layer) - ] + + get_offloader().wrap_modules( + ( + layer_fn(idx=idx, prefix=add_prefix(idx, prefix)) + for idx in range(start_layer, end_layer) + ), + **offloader_kwargs, + ) + [ PPMissingLayer(return_tuple=return_tuple) for _ in range(end_layer, num_hidden_layers)