[Feature] Option to save model weights to CPU when memory saver mode is enabled (#10873)
Co-authored-by: molocule <34072934+molocule@users.noreply.github.com>
@@ -305,6 +305,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--num-continuous-decode-steps` | Run multiple continuous decoding steps to reduce scheduling overhead. This can potentially increase throughput but may also increase time-to-first-token latency. The default value is 1, meaning only run one decoding step at a time. | 1 |
 | `--delete-ckpt-after-loading` | Delete the model checkpoint after loading the model. | False |
 | `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | False |
+| `--enable-weights-cpu-backup` | Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation | False |
 | `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | False |
 | `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | False |
 | `--flashinfer-mla-disable-ragged` | Disable ragged processing in Flashinfer MLA. | False |
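
The new flag pairs with `--enable-memory-saver`: when the weights are released, a copy is parked in host RAM, so resuming restores them without re-reading the checkpoint from disk. A minimal sketch of the engine-level round trip, mirroring the test added at the end of this diff (the model path is illustrative):

```python
import sglang as sgl

# Sketch: keep a CPU copy of the weights so a release/resume cycle
# restores them without reloading the checkpoint.
engine = sgl.Engine(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative; any model works
    enable_memory_saver=True,
    enable_weights_cpu_backup=True,
)
engine.release_memory_occupation()  # GPU weight pages freed; contents kept in host RAM
engine.resume_memory_occupation()   # pages re-allocated and refilled from the CPU backup
print(engine.generate("The capital of France is", {"temperature": 0})["text"])
engine.shutdown()
```
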
@@ -58,7 +58,7 @@ dependencies = [
     "tiktoken",
     "timm==1.0.16",
     "torch==2.8.0",
-    "torch_memory_saver==0.0.8",
+    "torch_memory_saver==0.0.9rc1",
     "torchao==0.9.0",
     "torchaudio==2.8.0",
     "torchvision",
@@ -110,7 +110,7 @@ srt_hpu = ["sglang[runtime_common]"]
 openai = ["openai==1.99.1", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-torch_memory_saver = ["torch_memory_saver==0.0.8"]
+torch_memory_saver = ["torch_memory_saver==0.0.9rc1"]
 decord = ["decord"]
 test = [
     "accelerate",
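
The pin moves from 0.0.8 to 0.0.9rc1 in both places because the older release's `region()` has no `enable_cpu_backup` parameter (see the adapter changes below). A quick sketch of a version check, assuming the singleton import path the `torch_memory_saver` package exposes:

```python
import inspect

from torch_memory_saver import torch_memory_saver

# 0.0.8 fails this check: its region() signature has no enable_cpu_backup.
sig = inspect.signature(torch_memory_saver.region)
assert "enable_cpu_backup" in sig.parameters, "need torch_memory_saver >= 0.0.9"
```
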
@@ -25,9 +25,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
-from urllib.parse import urlparse

-import requests
 import torch
 import torch.distributed as dist

@@ -36,7 +34,6 @@ from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
-from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_pp_group,
@@ -132,7 +129,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -143,7 +139,6 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    parse_connector_type,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -616,7 +611,7 @@ class ModelRunner:
             server_args.hicache_io_backend = "direct"
             logger.warning(
                 "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
             )

     def init_torch_distributed(self):
@@ -778,7 +773,10 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
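
This is the only load-path change: the weight-allocation region is now opened with the new flag, and the adapter (changed further down) forwards it to the library. A sketch of the underlying pattern, assuming the 0.0.9 `torch_memory_saver` API and an illustrative tag:

```python
import torch
from torch_memory_saver import torch_memory_saver

# Allocations made inside the region are tracked under the tag; with
# enable_cpu_backup=True their contents survive a pause/resume cycle.
with torch_memory_saver.region(tag="weights", enable_cpu_backup=True):
    w = torch.empty(1 << 20, device="cuda")  # stands in for real weight tensors

torch_memory_saver.pause(tag="weights")   # frees the GPU pages; data is backed up to CPU
torch_memory_saver.resume(tag="weights")  # re-maps GPU pages and restores the data
```
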
@@ -1106,7 +1104,7 @@ class ModelRunner:
             handle.wait()

             self.model.load_weights(weights)
-            return True, f"Succeeded to update parameter online."
+            return True, "Succeeded to update parameter online."

         except Exception as e:
             error_msg = (
@@ -1749,8 +1747,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
-                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
@@ -400,6 +400,7 @@ class ServerArgs:
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
+   enable_weights_cpu_backup: bool = False
    allow_auto_truncate: bool = False
    enable_custom_logit_processor: bool = False
    flashinfer_mla_disable_ragged: bool = False
@@ -2541,6 +2542,11 @@ class ServerArgs:
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--enable-weights-cpu-backup",
+            action="store_true",
+            help="Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation",
+        )
         parser.add_argument(
             "--allow-auto-truncate",
             action="store_true",
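
The new argument follows the same `store_true` pattern as its neighbors, so it defaults to False and simply flips the `enable_weights_cpu_backup` field added in the previous hunk. A standalone toy parser (not sglang's) illustrating the mechanics:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-weights-cpu-backup", action="store_true")

# Absent means False; present means True, with dashes mapped to underscores.
assert parser.parse_args([]).enable_weights_cpu_backup is False
assert parser.parse_args(["--enable-weights-cpu-backup"]).enable_weights_cpu_backup is True
```
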
@@ -1,8 +1,6 @@
 import logging
-import threading
-import time
 from abc import ABC
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager

 try:
     import torch_memory_saver
@@ -40,7 +38,7 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError

-    def region(self, tag: str):
+    def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError

     def pause(self, tag: str):
@@ -60,8 +58,8 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()

-    def region(self, tag: str):
-        return _memory_saver.region(tag=tag)
+    def region(self, tag: str, enable_cpu_backup: bool = False):
+        return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)

     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
@@ -80,7 +78,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield

     @contextmanager
-    def region(self, tag: str):
+    def region(self, tag: str, enable_cpu_backup: bool = False):
         yield

     def pause(self, tag: str):
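
All three adapter layers gain the parameter with a default of False, so existing call sites keep working unchanged: the abstract base fixes the signature, the real adapter forwards the flag to `torch_memory_saver`, and the no-op adapter (used when the library is not installed) accepts and ignores it.
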
@@ -25,8 +25,6 @@ configurations (tp=1, tp=2) to ensure proper memory management in distributed se
 data parallel size, we test it in verl.
 """

-import gc
-import os
 import time
 import unittest

@@ -52,7 +50,14 @@ def get_gpu_memory_gb():


 class TestReleaseMemoryOccupation(CustomTestCase):
-    def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1, ep_size=1):
+    def _setup_engine(
+        self,
+        model_name,
+        mem_fraction_static=0.8,
+        tp_size=1,
+        ep_size=1,
+        enable_weights_cpu_backup=False,
+    ):
         """Common setup for engine and HF model."""
         engine = sgl.Engine(
             model_path=model_name,
@@ -61,6 +66,7 @@ class TestReleaseMemoryOccupation(CustomTestCase):
             mem_fraction_static=mem_fraction_static,
             tp_size=tp_size,
             ep_size=ep_size,
+            enable_weights_cpu_backup=enable_weights_cpu_backup,
             # disable_cuda_graph=True, # for debugging only
         )

@@ -153,6 +159,53 @@ class TestReleaseMemoryOccupation(CustomTestCase):
         self.assertEqual(outputs, params["expect_output_after_update_weights"])
         engine.shutdown()

+    def test_release_and_resume_occupation_with_weights_cpu_backup(self):
+        # Test release and resume occupation with weights CPU backup
+        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        print("Testing test_release_and_resume_occupation_with_weights_cpu_backup")
+        engine = self._setup_engine(
+            model_name=model_name,
+            mem_fraction_static=0.6,
+            enable_weights_cpu_backup=True,
+        )
+        params = self._common_test_params()
+
+        self._test_initial_generation(
+            engine,
+            params["prompt"],
+            params["sampling_params"],
+            params["expect_output_before_update_weights"],
+        )
+
+        t = time.perf_counter()
+        gpu_memory_usage_before_release = get_gpu_memory_gb()
+        engine.release_memory_occupation()
+        gpu_memory_usage_after_release = get_gpu_memory_gb()
+
+        self.assertLess(
+            gpu_memory_usage_after_release,
+            gpu_memory_usage_before_release,
+        )
+
+        print(
+            f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB"
+        )
+
+        if _DEBUG_EXTRA:
+            time.sleep(3)
+
+        t = time.perf_counter()
+        engine.resume_memory_occupation()
+        print(
+            f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB"
+        )
+
+        print("generate post resume")
+        outputs = engine.generate(params["prompt"], params["sampling_params"])["text"]
+        self.assertEqual(outputs, params["expect_output_before_update_weights"])
+        engine.shutdown()
+
     def test_multi_stage_release_and_resume(self):
         # With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM
         model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
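
The new test captures the feature's contract: it asserts that GPU memory drops after `release_memory_occupation()`, then generates again immediately after `resume_memory_occupation()` and still expects `expect_output_before_update_weights`, with no weight-update call in between. That only holds if the original weights came back from the CPU backup rather than being re-initialized, which is exactly what `--enable-weights-cpu-backup` provides.
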