From 8c574902106d8568eed74c9c20eab66bf1b1f16c Mon Sep 17 00:00:00 2001
From: Matt Nappo
Date: Fri, 3 Oct 2025 04:48:19 -0400
Subject: [PATCH] [Feature] Option to save model weights to CPU when memory
 saver mode is enabled (#10873)

Co-authored-by: molocule <34072934+molocule@users.noreply.github.com>
---
 docs/advanced_features/server_arguments.md    |  1 +
 python/pyproject.toml                         |  2 +-
 python/pyproject_other.toml                   |  2 +-
 .../sglang/srt/model_executor/model_runner.py | 18 +++---
 python/sglang/srt/server_args.py              |  6 ++
 .../sglang/srt/torch_memory_saver_adapter.py  | 12 ++--
 test/srt/test_release_memory_occupation.py    | 59 ++++++++++++++++++-
 7 files changed, 78 insertions(+), 22 deletions(-)

diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 4e10a6402..e76533df7 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -305,6 +305,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--num-continuous-decode-steps` | Run multiple continuous decoding steps to reduce scheduling overhead. This can potentially increase throughput but may also increase time-to-first-token latency. The default value is 1, meaning only run one decoding step at a time. | 1 |
 | `--delete-ckpt-after-loading` | Delete the model checkpoint after loading the model. | False |
 | `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | False |
+| `--enable-weights-cpu-backup` | Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation | False |
 | `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | False |
 | `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | False |
 | `--flashinfer-mla-disable-ragged` | Disable ragged processing in Flashinfer MLA. | False |
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 6bc4eceb8..d112583bf 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -58,7 +58,7 @@ dependencies = [
     "tiktoken",
     "timm==1.0.16",
     "torch==2.8.0",
-    "torch_memory_saver==0.0.8",
+    "torch_memory_saver==0.0.9rc1",
     "torchao==0.9.0",
     "torchaudio==2.8.0",
     "torchvision",
diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml
index dd282dc5b..415baf41e 100755
--- a/python/pyproject_other.toml
+++ b/python/pyproject_other.toml
@@ -110,7 +110,7 @@ srt_hpu = ["sglang[runtime_common]"]
 openai = ["openai==1.99.1", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-torch_memory_saver = ["torch_memory_saver==0.0.8"]
+torch_memory_saver = ["torch_memory_saver==0.0.9rc1"]
 decord = ["decord"]
 test = [
     "accelerate",
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 42dfaecf7..a73a6be3b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -25,9 +25,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
-from urllib.parse import urlparse
 
-import requests
 import torch
 import torch.distributed as dist
 
@@ -36,7 +34,6 @@ from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
-from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_pp_group,
@@ -132,7 +129,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -143,7 +139,6 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    parse_connector_type,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -616,7 +611,7 @@ class ModelRunner:
                 server_args.hicache_io_backend = "direct"
                 logger.warning(
                     "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                 )
 
     def init_torch_distributed(self):
@@ -778,7 +773,10 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -1106,7 +1104,7 @@ class ModelRunner:
             handle.wait()
 
             self.model.load_weights(weights)
-            return True, f"Succeeded to update parameter online."
+            return True, "Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
@@ -1749,8 +1747,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
- f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem." + "Warning: Attention backend specified by --attention-backend or default backend might be overridden." + "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem." ) else: attn_backend = self._get_attention_backend_from_str( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 029502f5c..f9ff382a1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -400,6 +400,7 @@ class ServerArgs: num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False enable_memory_saver: bool = False + enable_weights_cpu_backup: bool = False allow_auto_truncate: bool = False enable_custom_logit_processor: bool = False flashinfer_mla_disable_ragged: bool = False @@ -2541,6 +2542,11 @@ class ServerArgs: action="store_true", help="Allow saving memory using release_memory_occupation and resume_memory_occupation", ) + parser.add_argument( + "--enable-weights-cpu-backup", + action="store_true", + help="Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation", + ) parser.add_argument( "--allow-auto-truncate", action="store_true", diff --git a/python/sglang/srt/torch_memory_saver_adapter.py b/python/sglang/srt/torch_memory_saver_adapter.py index a46151782..d00c97c5d 100644 --- a/python/sglang/srt/torch_memory_saver_adapter.py +++ b/python/sglang/srt/torch_memory_saver_adapter.py @@ -1,8 +1,6 @@ import logging -import threading -import time from abc import ABC -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager try: import torch_memory_saver @@ -40,7 +38,7 @@ class TorchMemorySaverAdapter(ABC): def configure_subprocess(self): raise NotImplementedError - def region(self, tag: str): + def region(self, tag: str, enable_cpu_backup: bool = False): raise NotImplementedError def pause(self, tag: str): @@ -60,8 +58,8 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter): def configure_subprocess(self): return torch_memory_saver.configure_subprocess() - def region(self, tag: str): - return _memory_saver.region(tag=tag) + def region(self, tag: str, enable_cpu_backup: bool = False): + return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup) def pause(self, tag: str): return _memory_saver.pause(tag=tag) @@ -80,7 +78,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): yield @contextmanager - def region(self, tag: str): + def region(self, tag: str, enable_cpu_backup: bool = False): yield def pause(self, tag: str): diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index 35be029df..071b1694e 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -25,8 +25,6 @@ configurations (tp=1, tp=2) to ensure proper memory management in distributed se data parallel size, we test it in verl. 
""" -import gc -import os import time import unittest @@ -52,7 +50,14 @@ def get_gpu_memory_gb(): class TestReleaseMemoryOccupation(CustomTestCase): - def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1, ep_size=1): + def _setup_engine( + self, + model_name, + mem_fraction_static=0.8, + tp_size=1, + ep_size=1, + enable_weights_cpu_backup=False, + ): """Common setup for engine and HF model.""" engine = sgl.Engine( model_path=model_name, @@ -61,6 +66,7 @@ class TestReleaseMemoryOccupation(CustomTestCase): mem_fraction_static=mem_fraction_static, tp_size=tp_size, ep_size=ep_size, + enable_weights_cpu_backup=enable_weights_cpu_backup, # disable_cuda_graph=True, # for debugging only ) @@ -153,6 +159,53 @@ class TestReleaseMemoryOccupation(CustomTestCase): self.assertEqual(outputs, params["expect_output_after_update_weights"]) engine.shutdown() + def test_release_and_resume_occupation_with_weights_cpu_backup(self): + # Test release and resume occupation with weights CPU backup + model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + print("Testing test_release_and_resume_occupation_with_weights_cpu_backup") + engine = self._setup_engine( + model_name=model_name, + mem_fraction_static=0.6, + enable_weights_cpu_backup=True, + ) + params = self._common_test_params() + + self._test_initial_generation( + engine, + params["prompt"], + params["sampling_params"], + params["expect_output_before_update_weights"], + ) + + t = time.perf_counter() + gpu_memory_usage_before_release = get_gpu_memory_gb() + engine.release_memory_occupation() + gpu_memory_usage_after_release = get_gpu_memory_gb() + + self.assertLess( + gpu_memory_usage_after_release, + gpu_memory_usage_before_release, + ) + + print( + f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB" + ) + + if _DEBUG_EXTRA: + time.sleep(3) + + t = time.perf_counter() + engine.resume_memory_occupation() + print( + f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB" + ) + + print("generate post resume") + outputs = engine.generate(params["prompt"], params["sampling_params"])["text"] + self.assertEqual(outputs, params["expect_output_before_update_weights"]) + engine.shutdown() + def test_multi_stage_release_and_resume(self): # With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST