Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)
This commit is contained in:
@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
|
||||
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
|
||||
decord = ["decord"]
|
||||
test = [
|
||||
"accelerate",
|
||||
|
||||
3
python/sglang/srt/constants.py
Normal file
3
python/sglang/srt/constants.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# GPU Memory Types
|
||||
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
|
||||
GPU_MEMORY_TYPE_WEIGHTS = "weights"
|
||||
@@ -31,6 +31,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
@@ -90,7 +91,7 @@ class DecodeReqToTokenPool:
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
self.pre_alloc_size = pre_alloc_size
|
||||
with memory_saver_adapter.region():
|
||||
with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
|
||||
self.req_to_token = torch.zeros(
|
||||
(size + pre_alloc_size, max_context_len),
|
||||
dtype=torch.int32,
|
||||
|
||||
@@ -479,17 +479,15 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||
)
|
||||
|
||||
def release_memory_occupation(self):
|
||||
"""Release GPU occupation temporarily."""
|
||||
obj = ReleaseMemoryOccupationReqInput()
|
||||
def release_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||
obj = ReleaseMemoryOccupationReqInput(tags=tags)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.release_memory_occupation(obj, None)
|
||||
)
|
||||
|
||||
def resume_memory_occupation(self):
|
||||
"""Resume GPU occupation."""
|
||||
obj = ResumeMemoryOccupationReqInput()
|
||||
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||
obj = ResumeMemoryOccupationReqInput(tags=tags)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||
@@ -670,11 +668,9 @@ def _launch_subprocesses(
|
||||
|
||||
scheduler_procs = []
|
||||
if server_args.dp_size == 1:
|
||||
# Launch tensor parallel scheduler processes
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
scheduler_pipe_readers = []
|
||||
|
||||
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
||||
@@ -710,6 +706,7 @@ def _launch_subprocesses(
|
||||
writer,
|
||||
),
|
||||
)
|
||||
|
||||
with memory_saver_adapter.configure_subprocess():
|
||||
proc.start()
|
||||
scheduler_procs.append(proc)
|
||||
|
||||
@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:
|
||||
|
||||
@dataclass
|
||||
class ReleaseMemoryOccupationReqInput:
|
||||
pass
|
||||
# Optional tags to identify the memory region, which is primarily used for RL
|
||||
# Currently we only support `weights` and `kv_cache`
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:
|
||||
|
||||
@dataclass
|
||||
class ResumeMemoryOccupationReqInput:
|
||||
pass
|
||||
# Optional tags to identify the memory region, which is primarily used for RL
|
||||
# Currently we only support `weights` and `kv_cache`
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -36,6 +36,7 @@ from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
create_grammar_backend,
|
||||
@@ -450,8 +451,6 @@ class Scheduler(
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
t.start()
|
||||
self.parent_process = psutil.Process().parent()
|
||||
|
||||
# Init memory saver
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
@@ -2227,23 +2226,40 @@ class Scheduler(
|
||||
return GetWeightsByNameReqOutput(parameter)
|
||||
|
||||
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
|
||||
self.memory_saver_adapter.check_validity(
|
||||
caller_name="release_memory_occupation"
|
||||
)
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
self.tp_worker.worker.model_runner.model
|
||||
)
|
||||
self.memory_saver_adapter.pause()
|
||||
self.flush_cache()
|
||||
tags = recv_req.tags
|
||||
import subprocess
|
||||
|
||||
if tags is None:
|
||||
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||
|
||||
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
|
||||
self.flush_cache()
|
||||
|
||||
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
self.tp_worker.worker.model_runner.model
|
||||
)
|
||||
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
|
||||
|
||||
return ReleaseMemoryOccupationReqOutput()
|
||||
|
||||
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
|
||||
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
|
||||
self.memory_saver_adapter.resume()
|
||||
_import_static_state(
|
||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
tags = recv_req.tags
|
||||
if tags is None or len(tags) == 0:
|
||||
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||
|
||||
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
|
||||
_import_static_state(
|
||||
self.tp_worker.worker.model_runner.model,
|
||||
self.stashed_model_static_state,
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
|
||||
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
|
||||
|
||||
return ResumeMemoryOccupationReqOutput()
|
||||
|
||||
def slow_down(self, recv_req: SlowDownReqInput):
|
||||
|
||||
@@ -35,6 +35,7 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
||||
|
||||
@@ -54,6 +55,7 @@ class ReqToTokenPool:
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
@@ -61,7 +63,7 @@ class ReqToTokenPool:
|
||||
self.size = size
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
with memory_saver_adapter.region():
|
||||
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
self.req_to_token = torch.zeros(
|
||||
(size, max_context_len), dtype=torch.int32, device=device
|
||||
)
|
||||
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
|
||||
)
|
||||
|
||||
def _create_buffers(self):
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.enable_custom_mem_pool
|
||||
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
end_layer,
|
||||
)
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
|
||||
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.distributed import (
|
||||
get_tp_group,
|
||||
get_world_group,
|
||||
@@ -222,6 +223,7 @@ class ModelRunner:
|
||||
|
||||
def initialize(self, min_per_gpu_memory: float):
|
||||
server_args = self.server_args
|
||||
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=self.server_args.enable_memory_saver
|
||||
)
|
||||
@@ -547,7 +549,7 @@ class ModelRunner:
|
||||
monkey_patch_vllm_parallel_state()
|
||||
monkey_patch_isinstance_for_vllm_base_layer()
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
try:
|
||||
import torch_memory_saver
|
||||
|
||||
_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
|
||||
_memory_saver = torch_memory_saver.torch_memory_saver
|
||||
import_error = None
|
||||
except ImportError as e:
|
||||
import_error = e
|
||||
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
|
||||
def configure_subprocess(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def region(self):
|
||||
def region(self, tag: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def pause(self):
|
||||
def pause(self, tag: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def resume(self):
|
||||
def resume(self, tag: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
|
||||
|
||||
|
||||
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
||||
"""Adapter for TorchMemorySaver with tag-based control"""
|
||||
|
||||
def configure_subprocess(self):
|
||||
return torch_memory_saver.configure_subprocess()
|
||||
|
||||
def region(self):
|
||||
return _primary_memory_saver.region()
|
||||
def region(self, tag: str):
|
||||
return _memory_saver.region(tag=tag)
|
||||
|
||||
def pause(self):
|
||||
return _primary_memory_saver.pause()
|
||||
def pause(self, tag: str):
|
||||
return _memory_saver.pause(tag=tag)
|
||||
|
||||
def resume(self):
|
||||
return _primary_memory_saver.resume()
|
||||
def resume(self, tag: str):
|
||||
return _memory_saver.resume(tag=tag)
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return _primary_memory_saver.enabled
|
||||
return _memory_saver is not None and _memory_saver.enabled
|
||||
|
||||
|
||||
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def region(self):
|
||||
def region(self, tag: str):
|
||||
yield
|
||||
|
||||
def pause(self):
|
||||
def pause(self, tag: str):
|
||||
pass
|
||||
|
||||
def resume(self):
|
||||
def resume(self, tag: str):
|
||||
pass
|
||||
|
||||
@property
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
|
||||
# General test models
|
||||
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ suites = {
|
||||
TestFile("test_radix_attention.py", 105),
|
||||
TestFile("test_reasoning_content.py", 89),
|
||||
TestFile("test_regex_constrained.py", 64),
|
||||
TestFile("test_release_memory_occupation.py", 44),
|
||||
TestFile("test_request_length_validation.py", 31),
|
||||
TestFile("test_retract_decode.py", 54),
|
||||
TestFile("test_server_args.py", 1),
|
||||
@@ -146,6 +145,7 @@ suites = {
|
||||
TestFile("test_patch_torch.py", 19),
|
||||
TestFile("test_update_weights_from_distributed.py", 103),
|
||||
TestFile("test_verl_engine_2_gpu.py", 64),
|
||||
TestFile("test_release_memory_occupation.py", 44),
|
||||
],
|
||||
"per-commit-2-gpu-amd": [
|
||||
TestFile("models/lora/test_lora_tp.py", 116),
|
||||
|
||||
@@ -1,3 +1,32 @@
|
||||
"""Test memory release and resume operations for SGLang engine in hybrid RL training.
|
||||
|
||||
This test suite evaluates the SGLang engine's memory management capabilities, focusing
|
||||
on releasing and resuming memory occupation for KV cache and model weights. It simulates
|
||||
an RL workflow where the SGLang engine acts as a rollout engine for experience collection.
|
||||
The process involves initializing the engine, sending a small number of requests to simulate
|
||||
rollout, releasing memory to mimic offloading during RL training, resuming memory occupation,
|
||||
updating weights with a trained HuggingFace model, and verifying the updated weights.
|
||||
|
||||
Detailed in our proposal (https://github.com/sgl-project/sglang/pull/7099), two test cases
|
||||
are included:
|
||||
|
||||
1. Basic Release and Resume: Uses a lower mem_fraction_static (0.6) to control memory allocation
|
||||
and avoid OOM errors carefully. This test simulates a scenario without multi-stage memory management,
|
||||
ensuring the engine can release and resume memory occupation while maintaining functionality after
|
||||
weight updates.
|
||||
|
||||
2. Multi-Stage Release and Resume: Employs a higher mem_fraction_static (0.85) to simulate higher
|
||||
memory pressure, leveraging multi-stage memory management. It sequentially releases and resumes
|
||||
KV cache and model weights, verifying memory deallocation and reallocation at each stage, and
|
||||
ensuring correct weight updates and text generation.
|
||||
|
||||
3. Tensor Parallel Tests: Tests memory release and resume operations with different tensor parallel
|
||||
configurations (tp=1, tp=2) to ensure proper memory management in distributed settings. For different
|
||||
data parallel size, we test it in verl.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
@@ -5,93 +34,221 @@ import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
||||
CustomTestCase,
|
||||
)
|
||||
|
||||
# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
|
||||
_DEBUG_EXTRA = True
|
||||
_DEBUG_EXTRA = False
|
||||
|
||||
|
||||
def get_gpu_memory_gb():
|
||||
return torch.cuda.device_memory_used() / 1024**3
|
||||
|
||||
|
||||
class TestReleaseMemoryOccupation(CustomTestCase):
|
||||
def test_release_and_resume_occupation(self):
|
||||
prompt = "Today is a sunny day and I like"
|
||||
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
expect_output = " to spend it outdoors. I decided to"
|
||||
|
||||
def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1):
|
||||
"""Common setup for engine and HF model."""
|
||||
engine = sgl.Engine(
|
||||
model_path=model_name,
|
||||
random_seed=42,
|
||||
enable_memory_saver=True,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
tp_size=tp_size,
|
||||
# disable_cuda_graph=True, # for debugging only
|
||||
)
|
||||
hf_model_new = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype="bfloat16"
|
||||
)
|
||||
|
||||
return engine
|
||||
|
||||
def _common_test_params(self):
|
||||
"""Common test parameters."""
|
||||
return {
|
||||
"prompt": "Today is a sunny day and I like",
|
||||
"sampling_params": {"temperature": 0, "max_new_tokens": 8},
|
||||
"expect_output_before_update_weights": " to spend it outdoors. I decided to",
|
||||
"expect_output_after_update_weights": " to go for a walk. I like",
|
||||
}
|
||||
|
||||
def _test_initial_generation(
|
||||
self, engine, prompt, sampling_params, expect_output_before_update_weights
|
||||
):
|
||||
"""Test initial generation and memory allocation."""
|
||||
print("generate (#1)")
|
||||
outputs = engine.generate(prompt, sampling_params)["text"]
|
||||
self.assertEqual(outputs, expect_output)
|
||||
self.assertEqual(outputs, expect_output_before_update_weights)
|
||||
|
||||
if _DEBUG_EXTRA:
|
||||
time.sleep(3)
|
||||
|
||||
self.assertEqual(
|
||||
_try_allocate_big_tensor(),
|
||||
False,
|
||||
"Should not be able to allocate big tensors before releasing",
|
||||
)
|
||||
def test_release_and_resume_occupation(self):
|
||||
# Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM
|
||||
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
assert (
|
||||
torch.cuda.device_count() >= 2
|
||||
), "Need at least 2 GPUs for tensor parallel tests"
|
||||
|
||||
print("release_memory_occupation start")
|
||||
t = time.perf_counter()
|
||||
engine.release_memory_occupation()
|
||||
if _DEBUG_EXTRA:
|
||||
print("release_memory_occupation", time.perf_counter() - t)
|
||||
for tp_size in [1, 2]:
|
||||
|
||||
if _DEBUG_EXTRA:
|
||||
time.sleep(5)
|
||||
print(f"Testing tp_size={tp_size} for test_release_and_resume_occupation")
|
||||
engine = self._setup_engine(
|
||||
model_name=model_name, mem_fraction_static=0.6, tp_size=tp_size
|
||||
)
|
||||
params = self._common_test_params()
|
||||
|
||||
self.assertEqual(
|
||||
_try_allocate_big_tensor(),
|
||||
True,
|
||||
"Should be able to allocate big tensors aftre releasing",
|
||||
)
|
||||
self._test_initial_generation(
|
||||
engine,
|
||||
params["prompt"],
|
||||
params["sampling_params"],
|
||||
params["expect_output_before_update_weights"],
|
||||
)
|
||||
|
||||
if _DEBUG_EXTRA:
|
||||
time.sleep(5)
|
||||
t = time.perf_counter()
|
||||
gpu_memory_usage_before_release = get_gpu_memory_gb()
|
||||
engine.release_memory_occupation()
|
||||
gpu_memory_usage_after_release = get_gpu_memory_gb()
|
||||
|
||||
print("resume_memory_occupation start")
|
||||
t = time.perf_counter()
|
||||
engine.resume_memory_occupation()
|
||||
if _DEBUG_EXTRA:
|
||||
print("resume_memory_occupation", time.perf_counter() - t)
|
||||
self.assertLess(
|
||||
gpu_memory_usage_after_release,
|
||||
gpu_memory_usage_before_release,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
_try_allocate_big_tensor(),
|
||||
False,
|
||||
"Should not be able to allocate big tensors after resuming",
|
||||
)
|
||||
print(
|
||||
f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB"
|
||||
)
|
||||
|
||||
print("update_weights_from_tensor")
|
||||
# As if: PPO has updated hf model's weights, and now we sync it to SGLang
|
||||
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
|
||||
if _DEBUG_EXTRA:
|
||||
time.sleep(3)
|
||||
|
||||
print("generate (#2)")
|
||||
outputs = engine.generate(prompt, sampling_params)["text"]
|
||||
self.assertEqual(outputs, expect_output)
|
||||
t = time.perf_counter()
|
||||
engine.resume_memory_occupation()
|
||||
print(
|
||||
f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB"
|
||||
)
|
||||
|
||||
if _DEBUG_EXTRA:
|
||||
time.sleep(4)
|
||||
hf_model_new = AutoModelForCausalLM.from_pretrained(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
||||
torch_dtype="bfloat16",
|
||||
device_map="cuda",
|
||||
)
|
||||
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
|
||||
|
||||
engine.shutdown()
|
||||
# destroy the hf model
|
||||
del hf_model_new
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print("generate (#2)")
|
||||
outputs = engine.generate(params["prompt"], params["sampling_params"])[
|
||||
"text"
|
||||
]
|
||||
self.assertEqual(outputs, params["expect_output_after_update_weights"])
|
||||
engine.shutdown()
|
||||
|
||||
def _try_allocate_big_tensor(size: int = 20_000_000_000):
|
||||
try:
|
||||
torch.empty((size,), dtype=torch.uint8, device="cuda")
|
||||
torch.cuda.empty_cache()
|
||||
return True
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
return False
|
||||
def test_multi_stage_release_and_resume(self):
|
||||
# With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM
|
||||
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
for tp_size in [1, 2]:
|
||||
if tp_size == 2 and torch.cuda.device_count() < 2:
|
||||
continue
|
||||
|
||||
print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume")
|
||||
engine = sgl.Engine(
|
||||
model_path=model_name,
|
||||
random_seed=42,
|
||||
enable_memory_saver=True,
|
||||
mem_fraction_static=0.85, # Higher memory pressure
|
||||
tp_size=tp_size,
|
||||
)
|
||||
params = self._common_test_params()
|
||||
|
||||
self._test_initial_generation(
|
||||
engine,
|
||||
params["prompt"],
|
||||
params["sampling_params"],
|
||||
params["expect_output_before_update_weights"],
|
||||
)
|
||||
|
||||
t = time.perf_counter()
|
||||
gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb()
|
||||
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
|
||||
|
||||
gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb()
|
||||
|
||||
self.assertLess(
|
||||
gpu_memory_usage_after_release_kv_cache,
|
||||
gpu_memory_usage_before_release_kv_cache,
|
||||
)
|
||||
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
|
||||
|
||||
gpu_memory_usage_after_release_weights = get_gpu_memory_gb()
|
||||
|
||||
self.assertLess(
|
||||
gpu_memory_usage_after_release_weights,
|
||||
gpu_memory_usage_after_release_kv_cache,
|
||||
)
|
||||
|
||||
print(f"Release took {time.perf_counter() - t:.2f}s")
|
||||
print(
|
||||
f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f} → {gpu_memory_usage_after_release_kv_cache:.1f} → {gpu_memory_usage_after_release_weights:.1f} GB"
|
||||
)
|
||||
|
||||
if _DEBUG_EXTRA:
|
||||
time.sleep(3)
|
||||
|
||||
t = time.perf_counter()
|
||||
gpu_memory_usage_before_resume_weights = get_gpu_memory_gb()
|
||||
|
||||
# gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close
|
||||
|
||||
self.assertAlmostEqual(
|
||||
gpu_memory_usage_after_release_weights,
|
||||
gpu_memory_usage_before_resume_weights,
|
||||
delta=3.0,
|
||||
)
|
||||
print(f"Resume weights took {time.perf_counter() - t:.2f}s")
|
||||
|
||||
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
|
||||
gpu_memory_usage_after_resume_weights = get_gpu_memory_gb()
|
||||
|
||||
self.assertGreater(
|
||||
gpu_memory_usage_after_resume_weights,
|
||||
gpu_memory_usage_before_resume_weights,
|
||||
)
|
||||
|
||||
# Update weights from a trained model to serving engine, and then destroy the trained model
|
||||
hf_model_new = AutoModelForCausalLM.from_pretrained(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
||||
torch_dtype="bfloat16",
|
||||
device_map="cuda",
|
||||
)
|
||||
gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb()
|
||||
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
|
||||
|
||||
# destroy the hf model
|
||||
del hf_model_new
|
||||
torch.cuda.empty_cache()
|
||||
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
|
||||
|
||||
gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb()
|
||||
self.assertGreater(
|
||||
gpu_memory_usage_after_resume_kv_cache,
|
||||
gpu_memory_usage_after_resume_weights,
|
||||
)
|
||||
|
||||
print(f"Resume + update took {time.perf_counter() - t:.2f}s")
|
||||
print(
|
||||
f"Memory: {gpu_memory_usage_before_resume_weights:.1f} → {gpu_memory_usage_after_resume_weights:.1f} → {gpu_memory_usage_after_loaded_hf_model:.1f} → {gpu_memory_usage_after_resume_kv_cache:.1f} GB"
|
||||
)
|
||||
|
||||
print("generate (#2)")
|
||||
outputs = engine.generate(params["prompt"], params["sampling_params"])[
|
||||
"text"
|
||||
]
|
||||
self.assertEqual(outputs, params["expect_output_after_update_weights"])
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -235,7 +235,8 @@ def _run_subprocess(
|
||||
output_writer.send(execution_ok)
|
||||
output_writer.close()
|
||||
|
||||
engine.shutdown()
|
||||
if "engine" in locals() and engine is not None:
|
||||
engine.shutdown()
|
||||
print(f"subprocess[{rank=}] end", flush=True)
|
||||
|
||||
|
||||
|
||||
@@ -249,7 +249,8 @@ def _run_subprocess(
|
||||
output_writer.send(execution_ok)
|
||||
output_writer.close()
|
||||
|
||||
engine.shutdown()
|
||||
if "engine" in locals() and engine is not None:
|
||||
engine.shutdown()
|
||||
print(f"subprocess[{rank=}] end", flush=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user