Refactor weight offloading logic (#8521)
This commit is contained in:
@@ -96,6 +96,11 @@ from sglang.srt.model_loader import get_model
|
|||||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.offloader import (
|
||||||
|
create_offloader_from_server_args,
|
||||||
|
get_offloader,
|
||||||
|
set_offloader,
|
||||||
|
)
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
|
|||||||
is_npu,
|
is_npu,
|
||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
set_cpu_offload_max_bytes,
|
|
||||||
set_cuda_arch,
|
set_cuda_arch,
|
||||||
)
|
)
|
||||||
from sglang.srt.weight_sync.tensor_bucket import (
|
from sglang.srt.weight_sync.tensor_bucket import (
|
||||||
@@ -222,9 +226,6 @@ class ModelRunner:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# CPU offload
|
|
||||||
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
|
|
||||||
|
|
||||||
# Init OpenMP threads binding for CPU
|
# Init OpenMP threads binding for CPU
|
||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
self.init_threads_binding()
|
self.init_threads_binding()
|
||||||
@@ -232,6 +233,9 @@ class ModelRunner:
|
|||||||
# Get memory before model loading
|
# Get memory before model loading
|
||||||
min_per_gpu_memory = self.init_torch_distributed()
|
min_per_gpu_memory = self.init_torch_distributed()
|
||||||
|
|
||||||
|
# CPU offload
|
||||||
|
set_offloader(create_offloader_from_server_args(server_args))
|
||||||
|
|
||||||
# Update deep gemm configure
|
# Update deep gemm configure
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||||
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
||||||
@@ -690,6 +694,8 @@ class ModelRunner:
|
|||||||
monkey_patch_vllm_parallel_state(reverse=True)
|
monkey_patch_vllm_parallel_state(reverse=True)
|
||||||
monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
|
monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
|
||||||
|
|
||||||
|
get_offloader().post_init()
|
||||||
|
|
||||||
if self.server_args.kv_cache_dtype == "fp8_e4m3":
|
if self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||||
if self.server_args.quantization_param_path is not None:
|
if self.server_args.quantization_param_path is not None:
|
||||||
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
||||||
|
|||||||
122
python/sglang/srt/offloader.py
Normal file
122
python/sglang/srt/offloader.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import logging
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Callable, Generator, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.func import functional_call
|
||||||
|
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import is_pin_memory_available
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
|
||||||
|
_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOffloader(ABC):
|
||||||
|
def wrap_modules(
|
||||||
|
self,
|
||||||
|
all_modules_generator: Generator[torch.nn.Module, None, None],
|
||||||
|
submodule_accessor: Optional[_SubmoduleAccessor] = None,
|
||||||
|
whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
|
||||||
|
):
|
||||||
|
return list(all_modules_generator)
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoopOffloader(BaseOffloader):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# For simplicity use singleton, but can surely support multi instance
|
||||||
|
_instance: Optional[BaseOffloader] = NoopOffloader()
|
||||||
|
|
||||||
|
|
||||||
|
def get_offloader():
|
||||||
|
assert _instance is not None
|
||||||
|
return _instance
|
||||||
|
|
||||||
|
|
||||||
|
def set_offloader(instance: BaseOffloader):
|
||||||
|
global _instance
|
||||||
|
_instance = instance
|
||||||
|
|
||||||
|
|
||||||
|
def create_offloader_from_server_args(server_args: ServerArgs):
|
||||||
|
if server_args.cpu_offload_gb > 0:
|
||||||
|
return OffloaderV1(
|
||||||
|
cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
|
||||||
|
)
|
||||||
|
return NoopOffloader()
|
||||||
|
|
||||||
|
|
||||||
|
class OffloaderV1(BaseOffloader):
|
||||||
|
def __init__(self, cpu_offload_max_bytes: int):
|
||||||
|
self._cpu_offload_bytes = 0
|
||||||
|
self._cpu_offload_max_bytes = cpu_offload_max_bytes
|
||||||
|
|
||||||
|
def wrap_modules(
|
||||||
|
self,
|
||||||
|
all_modules_generator: Generator[torch.nn.Module, None, None],
|
||||||
|
submodule_accessor: Optional[_SubmoduleAccessor] = None,
|
||||||
|
whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
|
||||||
|
):
|
||||||
|
return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]
|
||||||
|
|
||||||
|
def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||||
|
if (params := next(module.parameters(), None)) is None:
|
||||||
|
return module
|
||||||
|
|
||||||
|
device = params.device
|
||||||
|
|
||||||
|
if device == torch.device("cpu"):
|
||||||
|
return module
|
||||||
|
|
||||||
|
if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
|
||||||
|
return module
|
||||||
|
|
||||||
|
pin_memory = is_pin_memory_available()
|
||||||
|
# offload parameters to CPU
|
||||||
|
# use pin_memory if possible, which helps cudagraph capture speed
|
||||||
|
offloaded_parameters = False
|
||||||
|
for p in module.parameters():
|
||||||
|
if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
|
||||||
|
# we use per-parameter offloading
|
||||||
|
# one module might have some parameters offloaded and some not
|
||||||
|
break
|
||||||
|
|
||||||
|
# `torch.empty_like` does not support `pin_memory` argument
|
||||||
|
cpu_data = torch.empty_strided(
|
||||||
|
size=p.data.size(),
|
||||||
|
stride=p.data.stride(),
|
||||||
|
dtype=p.data.dtype,
|
||||||
|
layout=p.data.layout,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
cpu_data.copy_(p.data)
|
||||||
|
p.data = cpu_data
|
||||||
|
self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
|
||||||
|
offloaded_parameters = True
|
||||||
|
|
||||||
|
if offloaded_parameters:
|
||||||
|
original_forward = module.forward
|
||||||
|
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
module.forward = original_forward
|
||||||
|
device_state = {
|
||||||
|
# here we blindly call `to(device)`
|
||||||
|
# if the parameter is already on the device, it will be a no-op
|
||||||
|
k: v.to(device, non_blocking=True)
|
||||||
|
for k, v in module.state_dict().items()
|
||||||
|
}
|
||||||
|
output = functional_call(module, device_state, args=args, kwargs=kwargs)
|
||||||
|
module.forward = forward
|
||||||
|
return output
|
||||||
|
|
||||||
|
module.forward = forward
|
||||||
|
|
||||||
|
return module
|
||||||
@@ -438,72 +438,6 @@ def is_pin_memory_available() -> bool:
|
|||||||
return torch.cuda.is_available()
|
return torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
_CPU_OFFLOAD_BYTES = 0
|
|
||||||
_CPU_OFFLOAD_MAX_BYTES = 0
|
|
||||||
|
|
||||||
|
|
||||||
def set_cpu_offload_max_bytes(max_bytes: int) -> None:
|
|
||||||
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
|
|
||||||
_CPU_OFFLOAD_BYTES = 0
|
|
||||||
_CPU_OFFLOAD_MAX_BYTES = max_bytes
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
|
||||||
if (params := next(module.parameters(), None)) is None:
|
|
||||||
return module
|
|
||||||
|
|
||||||
device = params.device
|
|
||||||
if device == torch.device("cpu"):
|
|
||||||
return module
|
|
||||||
|
|
||||||
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
|
|
||||||
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
|
|
||||||
return module
|
|
||||||
|
|
||||||
pin_memory = is_pin_memory_available()
|
|
||||||
# offload parameters to CPU
|
|
||||||
# use pin_memory if possible, which helps cudagraph capture speed
|
|
||||||
offloaded_parameters = False
|
|
||||||
for p in module.parameters():
|
|
||||||
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
|
|
||||||
# we use per-parameter offloading
|
|
||||||
# one module might have some parameters offloaded and some not
|
|
||||||
break
|
|
||||||
|
|
||||||
# `torch.empty_like` does not support `pin_memory` argument
|
|
||||||
cpu_data = torch.empty_strided(
|
|
||||||
size=p.data.size(),
|
|
||||||
stride=p.data.stride(),
|
|
||||||
dtype=p.data.dtype,
|
|
||||||
layout=p.data.layout,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
)
|
|
||||||
cpu_data.copy_(p.data)
|
|
||||||
p.data = cpu_data
|
|
||||||
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
|
|
||||||
offloaded_parameters = True
|
|
||||||
|
|
||||||
if offloaded_parameters:
|
|
||||||
original_forward = module.forward
|
|
||||||
|
|
||||||
def forward(*args, **kwargs):
|
|
||||||
module.forward = original_forward
|
|
||||||
device_state = {
|
|
||||||
# here we blindly call `to(device)`
|
|
||||||
# if the parameter is already on the device, it will be a no-op
|
|
||||||
k: v.to(device, non_blocking=True)
|
|
||||||
for k, v in module.state_dict().items()
|
|
||||||
}
|
|
||||||
output = functional_call(module, device_state, args=args, kwargs=kwargs)
|
|
||||||
module.forward = forward
|
|
||||||
return output
|
|
||||||
|
|
||||||
module.forward = forward
|
|
||||||
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
class LayerFn(Protocol):
|
class LayerFn(Protocol):
|
||||||
|
|
||||||
def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
|
def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
|
||||||
@@ -516,11 +450,13 @@ def make_layers(
|
|||||||
pp_size: Optional[int] = None,
|
pp_size: Optional[int] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
return_tuple: bool = False,
|
return_tuple: bool = False,
|
||||||
|
offloader_kwargs: Dict[str, Any] = {},
|
||||||
) -> Tuple[int, int, torch.nn.ModuleList]:
|
) -> Tuple[int, int, torch.nn.ModuleList]:
|
||||||
"""Make a list of layers with the given layer function"""
|
"""Make a list of layers with the given layer function"""
|
||||||
# circula imports
|
# circula imports
|
||||||
from sglang.srt.distributed import get_pp_indices
|
from sglang.srt.distributed import get_pp_indices
|
||||||
from sglang.srt.layers.utils import PPMissingLayer
|
from sglang.srt.layers.utils import PPMissingLayer
|
||||||
|
from sglang.srt.offloader import get_offloader
|
||||||
|
|
||||||
assert not pp_size or num_hidden_layers >= pp_size
|
assert not pp_size or num_hidden_layers >= pp_size
|
||||||
start_layer, end_layer = (
|
start_layer, end_layer = (
|
||||||
@@ -534,10 +470,13 @@ def make_layers(
|
|||||||
)
|
)
|
||||||
modules = torch.nn.ModuleList(
|
modules = torch.nn.ModuleList(
|
||||||
[PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
|
[PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
|
||||||
+ [
|
+ get_offloader().wrap_modules(
|
||||||
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
|
(
|
||||||
for idx in range(start_layer, end_layer)
|
layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
|
||||||
]
|
for idx in range(start_layer, end_layer)
|
||||||
|
),
|
||||||
|
**offloader_kwargs,
|
||||||
|
)
|
||||||
+ [
|
+ [
|
||||||
PPMissingLayer(return_tuple=return_tuple)
|
PPMissingLayer(return_tuple=return_tuple)
|
||||||
for _ in range(end_layer, num_hidden_layers)
|
for _ in range(end_layer, num_hidden_layers)
|
||||||
|
|||||||
Reference in New Issue
Block a user