From d98fa1e93dc9af557c2dd0aa80f8ba80a2fe65e5 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Sat, 23 Nov 2024 08:23:53 +0200
Subject: [PATCH] Add simple CPU offloading support. (#2081)

---
 .../sglang/srt/model_executor/model_runner.py |  5 +-
 python/sglang/srt/models/gemma2.py            | 15 ++-
 python/sglang/srt/models/llama.py             | 15 +--
 python/sglang/srt/models/olmo.py              | 13 ++-
 python/sglang/srt/models/olmoe.py             | 13 ++-
 python/sglang/srt/models/qwen2.py             | 13 ++-
 python/sglang/srt/server_args.py              |  8 ++
 python/sglang/srt/utils.py                    | 91 ++++++++++++++++++-
 test/srt/test_srt_engine.py                   | 30 ++++++
 9 files changed, 174 insertions(+), 29 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3d5e450a4..3ba311b8c 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -61,6 +61,7 @@ from sglang.srt.utils import (
     is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )
 
 logger = logging.getLogger(__name__)
@@ -145,7 +146,9 @@ class ModelRunner:
             }
         )
 
-        # Init componnets
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index d5972c110..0fa6a5393 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -38,6 +38,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -267,11 +268,15 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 5ddc28ebd..7e9fd0f72 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class LlamaMLP(nn.Module):
@@ -255,14 +256,14 @@ class LlamaModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: LlamaDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            prefix="model.layers",
         )
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py
index 729039f93..80fd64a53 100644
--- a/python/sglang/srt/models/olmo.py
+++ b/python/sglang/srt/models/olmo.py
@@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class OlmoAttention(nn.Module):
@@ -220,11 +221,13 @@ class OlmoModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoDecoderLayer(config, layer_id, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = nn.LayerNorm(
             config.hidden_size, elementwise_affine=False, bias=False
diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py
index 99baa894d..d1e8e6027 100644
--- a/python/sglang/srt/models/olmoe.py
+++ b/python/sglang/srt/models/olmoe.py
@@ -48,6 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class OlmoeMoE(nn.Module):
@@ -261,11 +262,13 @@ class OlmoeModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoeDecoderLayer(
+                config=config,
+                quant_config=quant_config,
+                layer_id=idx,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
 
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 308e9008a..634ce1cf1 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 Qwen2Config = None
@@ -230,11 +231,13 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Qwen2DecoderLayer(config, i, quant_config=quant_config)
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Qwen2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7d2842e34..5a5cca918 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -62,6 +62,7 @@ class ServerArgs:
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
+    cpu_offload_gb: int = 0
 
     # Other runtime options
     tp_size: int = 1
@@ -373,6 +374,13 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
 
+        parser.add_argument(
+            "--cpu-offload-gb",
+            type=int,
+            default=ServerArgs.cpu_offload_gb,
+            help="How many GBs of RAM to reserve for CPU offloading",
+        )
+
         # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index e7f5392b9..fcba31a56 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -31,7 +31,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 import psutil
@@ -44,6 +44,7 @@ from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
+from torch.func import functional_call
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -190,6 +191,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
+def is_pin_memory_available() -> bool:
+    return torch.cuda.is_available()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 5d7f95440..f0dfa8f85 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -160,6 +160,36 @@ class TestSRTEngine(unittest.TestCase):
         result = throughput_test(server_args=server_args, bench_args=bench_args)
         self.assertGreater(result["total_throughput"], 3500)
 
+    def test_8_engine_cpu_offload(self):
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+            cpu_offload_gb=3,
+        )
+        out2 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        print("==== Answer 1 ====")
+        print(out1)
+
+        print("==== Answer 2 ====")
+        print(out2)
+        self.assertEqual(out1, out2)
+
 
 if __name__ == "__main__":
     unittest.main()
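
Usage sketch (illustrative, not part of the patch): with this change applied, weight
offloading is opted into via the new `--cpu-offload-gb` flag when launching the server
(e.g. through `python -m sglang.launch_server`), or via the matching `cpu_offload_gb`
argument of `sgl.Engine`. The snippet below mirrors the added test; the model path and
the 4 GB budget are placeholder values, and how many decoder layers actually land on the
CPU depends on how many fit inside that budget.

# Minimal sketch: run the offline engine with roughly 4 GB of weights offloaded
# to (pinned) CPU memory. The model path is a placeholder, not prescribed by the patch.
import sglang as sgl

engine = sgl.Engine(
    model_path="<model-path-or-hf-id>",  # placeholder
    cpu_offload_gb=4,  # reserve 4 GB of RAM for offloaded parameters
)
out = engine.generate(
    "Today is a sunny day and I like",
    {"temperature": 0, "max_new_tokens": 8},
)["text"]
print(out)
engine.shutdown()

As the added test asserts, output with offloading enabled is expected to match the
non-offloaded output, since only the storage location of the parameters changes.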