[Feature] Option to save model weights to CPU when memory saver mode is enabled (#10873)
Co-authored-by: molocule <34072934+molocule@users.noreply.github.com>
This commit is contained in:
@@ -58,7 +58,7 @@ dependencies = [
|
||||
"tiktoken",
|
||||
"timm==1.0.16",
|
||||
"torch==2.8.0",
|
||||
"torch_memory_saver==0.0.8",
|
||||
"torch_memory_saver==0.0.9rc1",
|
||||
"torchao==0.9.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
|
||||
@@ -110,7 +110,7 @@ srt_hpu = ["sglang[runtime_common]"]
|
||||
openai = ["openai==1.99.1", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
torch_memory_saver = ["torch_memory_saver==0.0.8"]
|
||||
torch_memory_saver = ["torch_memory_saver==0.0.9rc1"]
|
||||
decord = ["decord"]
|
||||
test = [
|
||||
"accelerate",
|
||||
|
||||
@@ -25,9 +25,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -36,7 +34,6 @@ from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
|
||||
from sglang.srt.connector import ConnectorType
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
@@ -132,7 +129,6 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_cpu_ids_by_node,
|
||||
init_custom_process_group,
|
||||
is_blackwell,
|
||||
is_fa3_default_architecture,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
@@ -143,7 +139,6 @@ from sglang.srt.utils import (
|
||||
log_info_on_rank0,
|
||||
monkey_patch_p2p_access_check,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
parse_connector_type,
|
||||
set_cuda_arch,
|
||||
)
|
||||
from sglang.srt.weight_sync.tensor_bucket import (
|
||||
@@ -616,7 +611,7 @@ class ModelRunner:
|
||||
server_args.hicache_io_backend = "direct"
|
||||
logger.warning(
|
||||
"FlashAttention3 decode backend is not compatible with hierarchical cache. "
|
||||
f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
|
||||
"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
|
||||
)
|
||||
|
||||
def init_torch_distributed(self):
|
||||
@@ -778,7 +773,10 @@ class ModelRunner:
|
||||
monkey_patch_vllm_parallel_state()
|
||||
monkey_patch_isinstance_for_vllm_base_layer()
|
||||
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
|
||||
with self.memory_saver_adapter.region(
|
||||
GPU_MEMORY_TYPE_WEIGHTS,
|
||||
enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
|
||||
):
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
@@ -1106,7 +1104,7 @@ class ModelRunner:
|
||||
handle.wait()
|
||||
|
||||
self.model.load_weights(weights)
|
||||
return True, f"Succeeded to update parameter online."
|
||||
return True, "Succeeded to update parameter online."
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
@@ -1749,8 +1747,8 @@ class ModelRunner:
|
||||
f"prefill_backend={self.prefill_attention_backend_str}."
|
||||
)
|
||||
logger.warning(
|
||||
f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
|
||||
f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
|
||||
"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
|
||||
"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
|
||||
)
|
||||
else:
|
||||
attn_backend = self._get_attention_backend_from_str(
|
||||
|
||||
@@ -400,6 +400,7 @@ class ServerArgs:
|
||||
num_continuous_decode_steps: int = 1
|
||||
delete_ckpt_after_loading: bool = False
|
||||
enable_memory_saver: bool = False
|
||||
enable_weights_cpu_backup: bool = False
|
||||
allow_auto_truncate: bool = False
|
||||
enable_custom_logit_processor: bool = False
|
||||
flashinfer_mla_disable_ragged: bool = False
|
||||
@@ -2541,6 +2542,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-weights-cpu-backup",
|
||||
action="store_true",
|
||||
help="Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow-auto-truncate",
|
||||
action="store_true",
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import contextmanager
|
||||
|
||||
try:
|
||||
import torch_memory_saver
|
||||
@@ -40,7 +38,7 @@ class TorchMemorySaverAdapter(ABC):
|
||||
def configure_subprocess(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def region(self, tag: str):
|
||||
def region(self, tag: str, enable_cpu_backup: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def pause(self, tag: str):
|
||||
@@ -60,8 +58,8 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
||||
def configure_subprocess(self):
|
||||
return torch_memory_saver.configure_subprocess()
|
||||
|
||||
def region(self, tag: str):
|
||||
return _memory_saver.region(tag=tag)
|
||||
def region(self, tag: str, enable_cpu_backup: bool = False):
|
||||
return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
|
||||
|
||||
def pause(self, tag: str):
|
||||
return _memory_saver.pause(tag=tag)
|
||||
@@ -80,7 +78,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def region(self, tag: str):
|
||||
def region(self, tag: str, enable_cpu_backup: bool = False):
|
||||
yield
|
||||
|
||||
def pause(self, tag: str):
|
||||
|
||||
Reference in New Issue
Block a user