Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)
This commit is contained in:
@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
|
|||||||
openai = ["openai>=1.0", "tiktoken"]
|
openai = ["openai>=1.0", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0"]
|
anthropic = ["anthropic>=0.20.0"]
|
||||||
litellm = ["litellm>=1.0.0"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
|
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
|
||||||
decord = ["decord"]
|
decord = ["decord"]
|
||||||
test = [
|
test = [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
|
|||||||
3
python/sglang/srt/constants.py
Normal file
3
python/sglang/srt/constants.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# GPU Memory Types
|
||||||
|
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
|
||||||
|
GPU_MEMORY_TYPE_WEIGHTS = "weights"
|
||||||
@@ -31,6 +31,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
FAKE_BOOTSTRAP_HOST,
|
FAKE_BOOTSTRAP_HOST,
|
||||||
@@ -90,7 +91,7 @@ class DecodeReqToTokenPool:
|
|||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
self.device = device
|
self.device = device
|
||||||
self.pre_alloc_size = pre_alloc_size
|
self.pre_alloc_size = pre_alloc_size
|
||||||
with memory_saver_adapter.region():
|
with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
self.req_to_token = torch.zeros(
|
self.req_to_token = torch.zeros(
|
||||||
(size + pre_alloc_size, max_context_len),
|
(size + pre_alloc_size, max_context_len),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
|
|||||||
@@ -479,17 +479,15 @@ class Engine(EngineBase):
|
|||||||
self.tokenizer_manager.get_weights_by_name(obj, None)
|
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def release_memory_occupation(self):
|
def release_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||||
"""Release GPU occupation temporarily."""
|
obj = ReleaseMemoryOccupationReqInput(tags=tags)
|
||||||
obj = ReleaseMemoryOccupationReqInput()
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(
|
return loop.run_until_complete(
|
||||||
self.tokenizer_manager.release_memory_occupation(obj, None)
|
self.tokenizer_manager.release_memory_occupation(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def resume_memory_occupation(self):
|
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||||
"""Resume GPU occupation."""
|
obj = ResumeMemoryOccupationReqInput(tags=tags)
|
||||||
obj = ResumeMemoryOccupationReqInput()
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(
|
return loop.run_until_complete(
|
||||||
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||||
@@ -670,11 +668,9 @@ def _launch_subprocesses(
|
|||||||
|
|
||||||
scheduler_procs = []
|
scheduler_procs = []
|
||||||
if server_args.dp_size == 1:
|
if server_args.dp_size == 1:
|
||||||
# Launch tensor parallel scheduler processes
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=server_args.enable_memory_saver
|
enable=server_args.enable_memory_saver
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler_pipe_readers = []
|
scheduler_pipe_readers = []
|
||||||
|
|
||||||
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
||||||
@@ -710,6 +706,7 @@ def _launch_subprocesses(
|
|||||||
writer,
|
writer,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
with memory_saver_adapter.configure_subprocess():
|
with memory_saver_adapter.configure_subprocess():
|
||||||
proc.start()
|
proc.start()
|
||||||
scheduler_procs.append(proc)
|
scheduler_procs.append(proc)
|
||||||
|
|||||||
@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReleaseMemoryOccupationReqInput:
|
class ReleaseMemoryOccupationReqInput:
|
||||||
pass
|
# Optional tags to identify the memory region, which is primarily used for RL
|
||||||
|
# Currently we only support `weights` and `kv_cache`
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ResumeMemoryOccupationReqInput:
|
class ResumeMemoryOccupationReqInput:
|
||||||
pass
|
# Optional tags to identify the memory region, which is primarily used for RL
|
||||||
|
# Currently we only support `weights` and `kv_cache`
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from torch.distributed import barrier
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
INVALID_GRAMMAR_OBJ,
|
INVALID_GRAMMAR_OBJ,
|
||||||
create_grammar_backend,
|
create_grammar_backend,
|
||||||
@@ -450,8 +451,6 @@ class Scheduler(
|
|||||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||||
t.start()
|
t.start()
|
||||||
self.parent_process = psutil.Process().parent()
|
self.parent_process = psutil.Process().parent()
|
||||||
|
|
||||||
# Init memory saver
|
|
||||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=server_args.enable_memory_saver
|
enable=server_args.enable_memory_saver
|
||||||
)
|
)
|
||||||
@@ -2227,23 +2226,40 @@ class Scheduler(
|
|||||||
return GetWeightsByNameReqOutput(parameter)
|
return GetWeightsByNameReqOutput(parameter)
|
||||||
|
|
||||||
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
|
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
|
||||||
self.memory_saver_adapter.check_validity(
|
tags = recv_req.tags
|
||||||
caller_name="release_memory_occupation"
|
import subprocess
|
||||||
)
|
|
||||||
|
if tags is None:
|
||||||
|
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||||
|
|
||||||
|
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||||
|
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
|
||||||
|
self.flush_cache()
|
||||||
|
|
||||||
|
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||||
self.stashed_model_static_state = _export_static_state(
|
self.stashed_model_static_state = _export_static_state(
|
||||||
self.tp_worker.worker.model_runner.model
|
self.tp_worker.worker.model_runner.model
|
||||||
)
|
)
|
||||||
self.memory_saver_adapter.pause()
|
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
|
||||||
self.flush_cache()
|
|
||||||
return ReleaseMemoryOccupationReqOutput()
|
return ReleaseMemoryOccupationReqOutput()
|
||||||
|
|
||||||
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
|
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
|
||||||
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
|
tags = recv_req.tags
|
||||||
self.memory_saver_adapter.resume()
|
if tags is None or len(tags) == 0:
|
||||||
|
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||||
|
|
||||||
|
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||||
|
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
|
||||||
_import_static_state(
|
_import_static_state(
|
||||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
self.tp_worker.worker.model_runner.model,
|
||||||
|
self.stashed_model_static_state,
|
||||||
)
|
)
|
||||||
del self.stashed_model_static_state
|
del self.stashed_model_static_state
|
||||||
|
|
||||||
|
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||||
|
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
|
||||||
|
|
||||||
return ResumeMemoryOccupationReqOutput()
|
return ResumeMemoryOccupationReqOutput()
|
||||||
|
|
||||||
def slow_down(self, recv_req: SlowDownReqInput):
|
def slow_down(self, recv_req: SlowDownReqInput):
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ class ReqToTokenPool:
|
|||||||
device: str,
|
device: str,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=enable_memory_saver
|
enable=enable_memory_saver
|
||||||
)
|
)
|
||||||
@@ -61,7 +63,7 @@ class ReqToTokenPool:
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
self.device = device
|
self.device = device
|
||||||
with memory_saver_adapter.region():
|
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
self.req_to_token = torch.zeros(
|
self.req_to_token = torch.zeros(
|
||||||
(size, max_context_len), dtype=torch.int32, device=device
|
(size, max_context_len), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _create_buffers(self):
|
def _create_buffers(self):
|
||||||
with self.memory_saver_adapter.region():
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
with (
|
with (
|
||||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
if self.enable_custom_mem_pool
|
if self.enable_custom_mem_pool
|
||||||
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
|
|||||||
else:
|
else:
|
||||||
self.custom_mem_pool = None
|
self.custom_mem_pool = None
|
||||||
|
|
||||||
with self.memory_saver_adapter.region():
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
with (
|
with (
|
||||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
if self.custom_mem_pool
|
if self.custom_mem_pool
|
||||||
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
end_layer,
|
end_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.memory_saver_adapter.region():
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
# [size, head_num, head_dim] for each layer
|
# [size, head_num, head_dim] for each layer
|
||||||
self.k_buffer = [
|
self.k_buffer = [
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
|
|||||||
from sglang.srt.configs.device_config import DeviceConfig
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
|
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
get_world_group,
|
get_world_group,
|
||||||
@@ -222,6 +223,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def initialize(self, min_per_gpu_memory: float):
|
def initialize(self, min_per_gpu_memory: float):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=self.server_args.enable_memory_saver
|
enable=self.server_args.enable_memory_saver
|
||||||
)
|
)
|
||||||
@@ -547,7 +549,7 @@ class ModelRunner:
|
|||||||
monkey_patch_vllm_parallel_state()
|
monkey_patch_vllm_parallel_state()
|
||||||
monkey_patch_isinstance_for_vllm_base_layer()
|
monkey_patch_isinstance_for_vllm_base_layer()
|
||||||
|
|
||||||
with self.memory_saver_adapter.region():
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
|
||||||
self.model = get_model(
|
self.model = get_model(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
load_config=self.load_config,
|
load_config=self.load_config,
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, nullcontext
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch_memory_saver
|
import torch_memory_saver
|
||||||
|
|
||||||
_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
|
_memory_saver = torch_memory_saver.torch_memory_saver
|
||||||
import_error = None
|
import_error = None
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
import_error = e
|
import_error = e
|
||||||
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
|
|||||||
def configure_subprocess(self):
|
def configure_subprocess(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def region(self):
|
def region(self, tag: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def pause(self):
|
def pause(self, tag: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def resume(self):
|
def resume(self, tag: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
||||||
|
"""Adapter for TorchMemorySaver with tag-based control"""
|
||||||
|
|
||||||
def configure_subprocess(self):
|
def configure_subprocess(self):
|
||||||
return torch_memory_saver.configure_subprocess()
|
return torch_memory_saver.configure_subprocess()
|
||||||
|
|
||||||
def region(self):
|
def region(self, tag: str):
|
||||||
return _primary_memory_saver.region()
|
return _memory_saver.region(tag=tag)
|
||||||
|
|
||||||
def pause(self):
|
def pause(self, tag: str):
|
||||||
return _primary_memory_saver.pause()
|
return _memory_saver.pause(tag=tag)
|
||||||
|
|
||||||
def resume(self):
|
def resume(self, tag: str):
|
||||||
return _primary_memory_saver.resume()
|
return _memory_saver.resume(tag=tag)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self):
|
def enabled(self):
|
||||||
return _primary_memory_saver.enabled
|
return _memory_saver is not None and _memory_saver.enabled
|
||||||
|
|
||||||
|
|
||||||
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||||
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def region(self):
|
def region(self, tag: str):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
def pause(self):
|
def pause(self, tag: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def resume(self):
|
def resume(self, tag: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
|
|||||||
# General test models
|
# General test models
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
|
||||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ suites = {
|
|||||||
TestFile("test_radix_attention.py", 105),
|
TestFile("test_radix_attention.py", 105),
|
||||||
TestFile("test_reasoning_content.py", 89),
|
TestFile("test_reasoning_content.py", 89),
|
||||||
TestFile("test_regex_constrained.py", 64),
|
TestFile("test_regex_constrained.py", 64),
|
||||||
TestFile("test_release_memory_occupation.py", 44),
|
|
||||||
TestFile("test_request_length_validation.py", 31),
|
TestFile("test_request_length_validation.py", 31),
|
||||||
TestFile("test_retract_decode.py", 54),
|
TestFile("test_retract_decode.py", 54),
|
||||||
TestFile("test_server_args.py", 1),
|
TestFile("test_server_args.py", 1),
|
||||||
@@ -146,6 +145,7 @@ suites = {
|
|||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_update_weights_from_distributed.py", 103),
|
TestFile("test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_verl_engine_2_gpu.py", 64),
|
TestFile("test_verl_engine_2_gpu.py", 64),
|
||||||
|
TestFile("test_release_memory_occupation.py", 44),
|
||||||
],
|
],
|
||||||
"per-commit-2-gpu-amd": [
|
"per-commit-2-gpu-amd": [
|
||||||
TestFile("models/lora/test_lora_tp.py", 116),
|
TestFile("models/lora/test_lora_tp.py", 116),
|
||||||
|
|||||||
@@ -1,3 +1,32 @@
|
|||||||
|
"""Test memory release and resume operations for SGLang engine in hybrid RL training.
|
||||||
|
|
||||||
|
This test suite evaluates the SGLang engine's memory management capabilities, focusing
|
||||||
|
on releasing and resuming memory occupation for KV cache and model weights. It simulates
|
||||||
|
an RL workflow where the SGLang engine acts as a rollout engine for experience collection.
|
||||||
|
The process involves initializing the engine, sending a small number of requests to simulate
|
||||||
|
rollout, releasing memory to mimic offloading during RL training, resuming memory occupation,
|
||||||
|
updating weights with a trained HuggingFace model, and verifying the updated weights.
|
||||||
|
|
||||||
|
As detailed in our proposal (https://github.com/sgl-project/sglang/pull/7099), three test cases
|
||||||
|
are included:
|
||||||
|
|
||||||
|
1. Basic Release and Resume: Uses a lower mem_fraction_static (0.6) to control memory allocation
|
||||||
|
and avoid OOM errors carefully. This test simulates a scenario without multi-stage memory management,
|
||||||
|
ensuring the engine can release and resume memory occupation while maintaining functionality after
|
||||||
|
weight updates.
|
||||||
|
|
||||||
|
2. Multi-Stage Release and Resume: Employs a higher mem_fraction_static (0.85) to simulate higher
|
||||||
|
memory pressure, leveraging multi-stage memory management. It sequentially releases and resumes
|
||||||
|
KV cache and model weights, verifying memory deallocation and reallocation at each stage, and
|
||||||
|
ensuring correct weight updates and text generation.
|
||||||
|
|
||||||
|
3. Tensor Parallel Tests: Tests memory release and resume operations with different tensor parallel
|
||||||
|
configurations (tp=1, tp=2) to ensure proper memory management in distributed settings. For different
|
||||||
|
data parallel sizes, we test them in verl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -5,93 +34,221 @@ import torch
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
||||||
|
CustomTestCase,
|
||||||
|
)
|
||||||
|
|
||||||
# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
|
# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
|
||||||
_DEBUG_EXTRA = True
|
_DEBUG_EXTRA = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpu_memory_gb():
|
||||||
|
return torch.cuda.device_memory_used() / 1024**3
|
||||||
|
|
||||||
|
|
||||||
class TestReleaseMemoryOccupation(CustomTestCase):
|
class TestReleaseMemoryOccupation(CustomTestCase):
|
||||||
def test_release_and_resume_occupation(self):
|
def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1):
|
||||||
prompt = "Today is a sunny day and I like"
|
"""Common setup for engine and HF model."""
|
||||||
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
|
||||||
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
||||||
expect_output = " to spend it outdoors. I decided to"
|
|
||||||
|
|
||||||
engine = sgl.Engine(
|
engine = sgl.Engine(
|
||||||
model_path=model_name,
|
model_path=model_name,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
enable_memory_saver=True,
|
enable_memory_saver=True,
|
||||||
|
mem_fraction_static=mem_fraction_static,
|
||||||
|
tp_size=tp_size,
|
||||||
# disable_cuda_graph=True, # for debugging only
|
# disable_cuda_graph=True, # for debugging only
|
||||||
)
|
)
|
||||||
hf_model_new = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name, torch_dtype="bfloat16"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
return engine
|
||||||
|
|
||||||
|
def _common_test_params(self):
|
||||||
|
"""Common test parameters."""
|
||||||
|
return {
|
||||||
|
"prompt": "Today is a sunny day and I like",
|
||||||
|
"sampling_params": {"temperature": 0, "max_new_tokens": 8},
|
||||||
|
"expect_output_before_update_weights": " to spend it outdoors. I decided to",
|
||||||
|
"expect_output_after_update_weights": " to go for a walk. I like",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _test_initial_generation(
|
||||||
|
self, engine, prompt, sampling_params, expect_output_before_update_weights
|
||||||
|
):
|
||||||
|
"""Test initial generation and memory allocation."""
|
||||||
print("generate (#1)")
|
print("generate (#1)")
|
||||||
outputs = engine.generate(prompt, sampling_params)["text"]
|
outputs = engine.generate(prompt, sampling_params)["text"]
|
||||||
self.assertEqual(outputs, expect_output)
|
self.assertEqual(outputs, expect_output_before_update_weights)
|
||||||
|
|
||||||
if _DEBUG_EXTRA:
|
if _DEBUG_EXTRA:
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
self.assertEqual(
|
def test_release_and_resume_occupation(self):
|
||||||
_try_allocate_big_tensor(),
|
# Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM
|
||||||
False,
|
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
"Should not be able to allocate big tensors before releasing",
|
assert (
|
||||||
|
torch.cuda.device_count() >= 2
|
||||||
|
), "Need at least 2 GPUs for tensor parallel tests"
|
||||||
|
|
||||||
|
for tp_size in [1, 2]:
|
||||||
|
|
||||||
|
print(f"Testing tp_size={tp_size} for test_release_and_resume_occupation")
|
||||||
|
engine = self._setup_engine(
|
||||||
|
model_name=model_name, mem_fraction_static=0.6, tp_size=tp_size
|
||||||
|
)
|
||||||
|
params = self._common_test_params()
|
||||||
|
|
||||||
|
self._test_initial_generation(
|
||||||
|
engine,
|
||||||
|
params["prompt"],
|
||||||
|
params["sampling_params"],
|
||||||
|
params["expect_output_before_update_weights"],
|
||||||
)
|
)
|
||||||
|
|
||||||
print("release_memory_occupation start")
|
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
|
gpu_memory_usage_before_release = get_gpu_memory_gb()
|
||||||
engine.release_memory_occupation()
|
engine.release_memory_occupation()
|
||||||
if _DEBUG_EXTRA:
|
gpu_memory_usage_after_release = get_gpu_memory_gb()
|
||||||
print("release_memory_occupation", time.perf_counter() - t)
|
|
||||||
|
|
||||||
if _DEBUG_EXTRA:
|
self.assertLess(
|
||||||
time.sleep(5)
|
gpu_memory_usage_after_release,
|
||||||
|
gpu_memory_usage_before_release,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
print(
|
||||||
_try_allocate_big_tensor(),
|
f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB"
|
||||||
True,
|
|
||||||
"Should be able to allocate big tensors aftre releasing",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if _DEBUG_EXTRA:
|
if _DEBUG_EXTRA:
|
||||||
time.sleep(5)
|
time.sleep(3)
|
||||||
|
|
||||||
print("resume_memory_occupation start")
|
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
engine.resume_memory_occupation()
|
engine.resume_memory_occupation()
|
||||||
if _DEBUG_EXTRA:
|
print(
|
||||||
print("resume_memory_occupation", time.perf_counter() - t)
|
f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB"
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
_try_allocate_big_tensor(),
|
|
||||||
False,
|
|
||||||
"Should not be able to allocate big tensors after resuming",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print("update_weights_from_tensor")
|
hf_model_new = AutoModelForCausalLM.from_pretrained(
|
||||||
# As if: PPO has updated hf model's weights, and now we sync it to SGLang
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
||||||
|
torch_dtype="bfloat16",
|
||||||
|
device_map="cuda",
|
||||||
|
)
|
||||||
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
|
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
|
||||||
|
|
||||||
|
# destroy the hf model
|
||||||
|
del hf_model_new
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
print("generate (#2)")
|
print("generate (#2)")
|
||||||
outputs = engine.generate(prompt, sampling_params)["text"]
|
outputs = engine.generate(params["prompt"], params["sampling_params"])[
|
||||||
self.assertEqual(outputs, expect_output)
|
"text"
|
||||||
|
]
|
||||||
if _DEBUG_EXTRA:
|
self.assertEqual(outputs, params["expect_output_after_update_weights"])
|
||||||
time.sleep(4)
|
|
||||||
|
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|
||||||
|
def test_multi_stage_release_and_resume(self):
|
||||||
|
# With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM
|
||||||
|
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
def _try_allocate_big_tensor(size: int = 20_000_000_000):
|
for tp_size in [1, 2]:
|
||||||
try:
|
if tp_size == 2 and torch.cuda.device_count() < 2:
|
||||||
torch.empty((size,), dtype=torch.uint8, device="cuda")
|
continue
|
||||||
|
|
||||||
|
print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume")
|
||||||
|
engine = sgl.Engine(
|
||||||
|
model_path=model_name,
|
||||||
|
random_seed=42,
|
||||||
|
enable_memory_saver=True,
|
||||||
|
mem_fraction_static=0.85, # Higher memory pressure
|
||||||
|
tp_size=tp_size,
|
||||||
|
)
|
||||||
|
params = self._common_test_params()
|
||||||
|
|
||||||
|
self._test_initial_generation(
|
||||||
|
engine,
|
||||||
|
params["prompt"],
|
||||||
|
params["sampling_params"],
|
||||||
|
params["expect_output_before_update_weights"],
|
||||||
|
)
|
||||||
|
|
||||||
|
t = time.perf_counter()
|
||||||
|
gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb()
|
||||||
|
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
|
||||||
|
|
||||||
|
gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb()
|
||||||
|
|
||||||
|
self.assertLess(
|
||||||
|
gpu_memory_usage_after_release_kv_cache,
|
||||||
|
gpu_memory_usage_before_release_kv_cache,
|
||||||
|
)
|
||||||
|
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
|
||||||
|
|
||||||
|
gpu_memory_usage_after_release_weights = get_gpu_memory_gb()
|
||||||
|
|
||||||
|
self.assertLess(
|
||||||
|
gpu_memory_usage_after_release_weights,
|
||||||
|
gpu_memory_usage_after_release_kv_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Release took {time.perf_counter() - t:.2f}s")
|
||||||
|
print(
|
||||||
|
f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f} → {gpu_memory_usage_after_release_kv_cache:.1f} → {gpu_memory_usage_after_release_weights:.1f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
|
if _DEBUG_EXTRA:
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
t = time.perf_counter()
|
||||||
|
gpu_memory_usage_before_resume_weights = get_gpu_memory_gb()
|
||||||
|
|
||||||
|
# gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close
|
||||||
|
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
gpu_memory_usage_after_release_weights,
|
||||||
|
gpu_memory_usage_before_resume_weights,
|
||||||
|
delta=3.0,
|
||||||
|
)
|
||||||
|
print(f"Resume weights took {time.perf_counter() - t:.2f}s")
|
||||||
|
|
||||||
|
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
|
||||||
|
gpu_memory_usage_after_resume_weights = get_gpu_memory_gb()
|
||||||
|
|
||||||
|
self.assertGreater(
|
||||||
|
gpu_memory_usage_after_resume_weights,
|
||||||
|
gpu_memory_usage_before_resume_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update weights from a trained model to serving engine, and then destroy the trained model
|
||||||
|
hf_model_new = AutoModelForCausalLM.from_pretrained(
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
||||||
|
torch_dtype="bfloat16",
|
||||||
|
device_map="cuda",
|
||||||
|
)
|
||||||
|
gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb()
|
||||||
|
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
|
||||||
|
|
||||||
|
# destroy the hf model
|
||||||
|
del hf_model_new
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return True
|
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
|
||||||
except torch.cuda.OutOfMemoryError:
|
|
||||||
return False
|
gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb()
|
||||||
|
self.assertGreater(
|
||||||
|
gpu_memory_usage_after_resume_kv_cache,
|
||||||
|
gpu_memory_usage_after_resume_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Resume + update took {time.perf_counter() - t:.2f}s")
|
||||||
|
print(
|
||||||
|
f"Memory: {gpu_memory_usage_before_resume_weights:.1f} → {gpu_memory_usage_after_resume_weights:.1f} → {gpu_memory_usage_after_loaded_hf_model:.1f} → {gpu_memory_usage_after_resume_kv_cache:.1f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("generate (#2)")
|
||||||
|
outputs = engine.generate(params["prompt"], params["sampling_params"])[
|
||||||
|
"text"
|
||||||
|
]
|
||||||
|
self.assertEqual(outputs, params["expect_output_after_update_weights"])
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ def _run_subprocess(
|
|||||||
output_writer.send(execution_ok)
|
output_writer.send(execution_ok)
|
||||||
output_writer.close()
|
output_writer.close()
|
||||||
|
|
||||||
|
if "engine" in locals() and engine is not None:
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
print(f"subprocess[{rank=}] end", flush=True)
|
print(f"subprocess[{rank=}] end", flush=True)
|
||||||
|
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ def _run_subprocess(
|
|||||||
output_writer.send(execution_ok)
|
output_writer.send(execution_ok)
|
||||||
output_writer.close()
|
output_writer.close()
|
||||||
|
|
||||||
|
if "engine" in locals() and engine is not None:
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
print(f"subprocess[{rank=}] end", flush=True)
|
print(f"subprocess[{rank=}] end", flush=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user