diff --git a/python/pyproject.toml b/python/pyproject.toml index 4b627ae94..61a36e341 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -44,6 +44,7 @@ srt_hpu = ["sglang[runtime_common]"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] +torch_memory_saver = ["torch_memory_saver"] test = [ "jsonlines", "matplotlib", diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 26b8921c4..ec45696bf 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller). import uuid from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Tuple, Union - -import torch +from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -459,6 +457,26 @@ class GetWeightsByNameReqOutput: parameter: list +@dataclass +class ReleaseMemoryOccupationReqInput: + pass + + +@dataclass +class ReleaseMemoryOccupationReqOutput: + pass + + +@dataclass +class ResumeMemoryOccupationReqInput: + pass + + +@dataclass +class ResumeMemoryOccupationReqOutput: + pass + + @dataclass class AbortReq: # The request id diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1c07ea6ad..b9e74aa9d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import ( OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, @@ -88,6 +92,7 @@ from sglang.srt.utils import ( set_random_seed, 
suppress_other_loggers, ) +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -357,6 +362,10 @@ class Scheduler: t.start() self.parent_process = psutil.Process().parent() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + # Init profiler if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": self.profiler = None @@ -519,6 +528,12 @@ class Scheduler: elif isinstance(recv_req, GetWeightsByNameReqInput): parameter = self.get_weights_by_name(recv_req) self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) + elif isinstance(recv_req, ReleaseMemoryOccupationReqInput): + self.release_memory_occupation() + self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput()) + elif isinstance(recv_req, ResumeMemoryOccupationReqInput): + self.resume_memory_occupation() + self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput()) elif isinstance(recv_req, ProfileReq): if recv_req == ProfileReq.START_PROFILE: self.start_profile() @@ -1538,6 +1553,20 @@ class Scheduler: parameter = self.tp_worker.get_weights_by_name(recv_req) return parameter + def release_memory_occupation(self): + self.stashed_model_static_state = _export_static_state( + self.tp_worker.worker.model_runner.model + ) + self.memory_saver_adapter.pause() + self.flush_cache() + + def resume_memory_occupation(self): + self.memory_saver_adapter.resume() + _import_static_state( + self.tp_worker.worker.model_runner.model, self.stashed_model_static_state + ) + del self.stashed_model_static_state + def start_profile(self) -> None: if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -1576,6 +1605,20 @@ class Scheduler: del self.sessions[session_id] +def _export_static_state(model): + return dict( + buffers=[ + (name, buffer.detach().clone()) for name, buffer in model.named_buffers() + ] + ) + + +def 
_import_static_state(model, static_params): + self_named_buffers = dict(model.named_buffers()) + for name, tensor in static_params["buffers"]: + self_named_buffers[name][...] = tensor + + def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fb6202932..33968e34f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -53,6 +53,10 @@ from sglang.srt.managers.io_struct import ( OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -188,6 +192,12 @@ class TokenizerManager: self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.release_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.resume_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) # Metrics if self.enable_metrics: @@ -548,6 +558,22 @@ class TokenizerManager: else: return all_parameters + async def release_memory_occupation( + self, + obj: ReleaseMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.release_memory_occupation_communicator(obj) + + async def resume_memory_occupation( + self, + obj: ResumeMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.resume_memory_occupation_communicator(obj) + async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): @@ -627,6 +653,8 @@ class TokenizerManager: UpdateWeightsFromDistributedReqOutput, GetWeightsByNameReqOutput, 
InitWeightsUpdateGroupReqOutput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqOutput, ] = await self.recv_from_detokenizer.recv_pyobj() if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): @@ -709,6 +737,10 @@ class TokenizerManager: self.update_weights_from_tensor_communicator.handle_recv(recv_obj) elif isinstance(recv_obj, GetWeightsByNameReqOutput): self.get_weights_by_name_communicator.handle_recv(recv_obj) + elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput): + self.release_memory_occupation_communicator.handle_recv(recv_obj) + elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput): + self.resume_memory_occupation_communicator.handle_recv(recv_obj) else: raise ValueError(f"Invalid object: {recv_obj=}") diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index abee7764b..0761169e4 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter + """ Memory pool. 
@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024 class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" - def __init__(self, size: int, max_context_len: int, device: str, use_records: bool): + def __init__( + self, + size: int, + max_context_len: int, + device: str, + use_records: bool, + enable_memory_saver: bool, + ): + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + self.size = size self.max_context_len = max_context_len self.device = device - self.req_to_token = torch.zeros( - (size, max_context_len), dtype=torch.int32, device=device - ) + with memory_saver_adapter.region(): + self.req_to_token = torch.zeros( + (size, max_context_len), dtype=torch.int32, device=device + ) self.free_slots = list(range(size)) self.write_records = [] self.use_records = use_records @@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool): head_dim: int, layer_num: int, device: str, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) + + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + self.head_num = head_num self.head_dim = head_dim self.layer_num = layer_num @@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool): ) def _create_buffers(self): - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.k_buffer = [ - torch.empty( - (self.size + 1, self.head_num, self.head_dim), - dtype=self.store_dtype, - device=self.device, - ) - for _ in range(self.layer_num) - ] - self.v_buffer = [ - torch.empty( - (self.size + 1, self.head_num, self.head_dim), - dtype=self.store_dtype, - device=self.device, - ) - for _ in range(self.layer_num) - ] + with self.memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
+ self.k_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] def _clear_buffers(self): del self.k_buffer @@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool): qk_rope_head_dim: int, layer_num: int, device: str, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) self.kv_lora_rank = kv_lora_rank - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.kv_buffer = [ - torch.empty( - (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=self.store_dtype, - device=device, - ) - for _ in range(layer_num) - ] + + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + with memory_saver_adapter.region(): + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
+ self.kv_buffer = [ + torch.empty( + (size + 1, 1, kv_lora_rank + qk_rope_head_dim), + dtype=self.store_dtype, + device=device, + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): if self.store_dtype != self.dtype: @@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool): layer_num: int, device: str, heavy_channel_num: int, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) - # [size, head_num, head_dim] for each layer - self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) - # [size, head_num, heavy_channel_num] for each layer - self.label_buffer = [ - torch.empty( - (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] + with memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + + # [size, head_num, heavy_channel_num] for each layer + self.label_buffer = [ + torch.empty( + (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): return self.k_buffer[layer_id] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d46a2c0dc..190427649 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -60,6 +60,7 @@ from sglang.srt.utils import ( monkey_patch_vllm_p2p_access_check, 
set_cpu_offload_max_bytes, ) +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter logger = logging.getLogger(__name__) @@ -166,6 +167,10 @@ class ModelRunner: # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=self.server_args.enable_memory_saver + ) + # Load the model self.sampler = Sampler() self.load_model() @@ -272,11 +277,12 @@ class ModelRunner: monkey_patch_vllm_gguf_config() # Load the model - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=DeviceConfig(self.device), - ) + with self.memory_saver_adapter.region(): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + ) if self.server_args.kv_cache_dtype == "fp8_e4m3": if self.server_args.quantization_param_path is not None: @@ -417,7 +423,7 @@ class ModelRunner: logger.info( f"init custom process group: master_address={master_address}, master_port={master_port}, " - f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}" ) try: @@ -590,6 +596,7 @@ class ModelRunner: max_context_len=self.model_config.context_len + 4, device=self.device, use_records=False, + enable_memory_saver=self.server_args.enable_memory_saver, ) if ( self.model_config.attention_arch == AttentionArch.MLA @@ -602,6 +609,7 @@ class ModelRunner: qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, ) elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( @@ -612,6 +620,7 @@ class ModelRunner: layer_num=self.model_config.num_hidden_layers, device=self.device, 
heavy_channel_num=self.server_args.ds_heavy_channel_num, + enable_memory_saver=self.server_args.enable_memory_saver, ) else: self.token_to_kv_pool = MHATokenToKVPool( @@ -621,6 +630,7 @@ class ModelRunner: head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, ) logger.info( f"Memory pool end. " diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index fa1625b09..4e837e538 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union import torch +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -57,6 +59,8 @@ from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, OpenSessionReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, @@ -255,6 +259,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): return _create_error_response(e) +@app.api_route("/release_memory_occupation", methods=["GET", "POST"]) +async def release_memory_occupation( + obj: ReleaseMemoryOccupationReqInput, request: Request +): + """Release GPU occupation temporarily""" + try: + await tokenizer_manager.release_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) +async def resume_memory_occupation( + obj: ResumeMemoryOccupationReqInput, request: Request +): + """Resume GPU occupation""" + try: + await tokenizer_manager.resume_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + 
@app.api_route("/open_session", methods=["GET", "POST"]) async def open_session(obj: OpenSessionReqInput, request: Request): """Open a session, and return its unique session id.""" @@ -438,6 +464,10 @@ def launch_engine( server_args.model_path, server_args.tokenizer_path ) + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + if server_args.dp_size == 1: # Launch tensor parallel scheduler processes scheduler_procs = [] @@ -454,7 +484,8 @@ def launch_engine( target=run_scheduler_process, args=(server_args, port_args, gpu_id, tp_rank, None, writer), ) - proc.start() + with memory_saver_adapter.configure_subprocess(): + proc.start() scheduler_procs.append(proc) scheduler_pipe_readers.append(reader) @@ -471,7 +502,8 @@ def launch_engine( target=run_data_parallel_controller_process, args=(server_args, port_args, writer), ) - proc.start() + with memory_saver_adapter.configure_subprocess(): + proc.start() # Launch detokenizer process detoken_proc = mp.Process( @@ -897,6 +929,18 @@ class Engine: loop = asyncio.get_event_loop() return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None)) + def release_memory_occupation(self): + """Release GPU occupation temporarily""" + obj = ReleaseMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None)) + + def resume_memory_occupation(self): + """Resume GPU occupation""" + obj = ResumeMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None)) + class Runtime: """ diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index be85a3670..4f44d5c87 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -23,7 +23,6 @@ from typing import List, Optional import torch from sglang.srt.hf_transformers_utils import check_gguf_file -from 
sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_hpu_memory_capacity, @@ -157,6 +156,7 @@ class ServerArgs: triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False + enable_memory_saver: bool = False def __post_init__(self): # Set missing default values @@ -854,6 +854,11 @@ class ServerArgs: action="store_true", help="Delete the model checkpoint after loading the model.", ) + parser.add_argument( + "--enable-memory-saver", + action="store_true", + help="Allow saving memory using release_memory_occupation and resume_memory_occupation", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/torch_memory_saver_adapter.py b/python/sglang/torch_memory_saver_adapter.py new file mode 100644 index 000000000..31f8ebf2f --- /dev/null +++ b/python/sglang/torch_memory_saver_adapter.py @@ -0,0 +1,59 @@ +from abc import ABC +from contextlib import contextmanager + +try: + import torch_memory_saver + + _primary_memory_saver = torch_memory_saver.TorchMemorySaver() +except ImportError: + pass + + +class TorchMemorySaverAdapter(ABC): + @staticmethod + def create(enable: bool): + return ( + _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop() + ) + + def configure_subprocess(self): + raise NotImplementedError + + def region(self): + raise NotImplementedError + + def pause(self): + raise NotImplementedError + + def resume(self): + raise NotImplementedError + + +class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter): + def configure_subprocess(self): + return torch_memory_saver.configure_subprocess() + + def region(self): + return _primary_memory_saver.region() + + def pause(self): + return _primary_memory_saver.pause() + + def resume(self): + return _primary_memory_saver.resume() + + +class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): + @contextmanager + def 
configure_subprocess(self): + yield + + @contextmanager + def region(self): + yield + + def pause(self): + pass + + def resume(self): + pass diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 26c34879e..66b113f61 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -12,8 +12,9 @@ bash "${SCRIPT_DIR}/killall_sglang.sh" pip install --upgrade pip pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ -# Force reinstall flashinfer +# Force reinstall flashinfer and torch_memory_saver pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip install torch_memory_saver --force-reinstall pip install transformers==4.45.2 sentence_transformers accelerate peft diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d617fcf69..658b3d2f8 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -29,6 +29,7 @@ suites = { "test_openai_server.py", "test_pytorch_sampling_backend.py", "test_radix_attention.py", + "test_release_memory_occupation.py", "test_retract_decode.py", "test_server_args.py", "test_session_control.py", diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py new file mode 100644 index 000000000..c84b64e77 --- /dev/null +++ b/test/srt/test_release_memory_occupation.py @@ -0,0 +1,98 @@ +import time +import unittest + +import torch +from transformers import AutoModelForCausalLM + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +# (temporarily) set to true to observe memory usage in nvidia-smi more clearly +_DEBUG_EXTRA = True + + +class TestReleaseMemoryOccupation(unittest.TestCase): + def test_release_and_resume_occupation(self): + prompt = "Today is a sunny day and I like" + sampling_params = {"temperature": 0, "max_new_tokens": 8} + model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + expect_output = " to spend it 
outdoors. I decided to"
+
+        engine = sgl.Engine(
+            model_path=model_name,
+            random_seed=42,
+            enable_memory_saver=True,
+            # disable_cuda_graph=True,  # for debugging only
+        )
+        hf_model_new = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype="bfloat16"
+        )
+
+        print("generate (#1)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output)
+
+        if _DEBUG_EXTRA:
+            time.sleep(3)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            False,
+            "Should not be able to allocate big tensors before releasing",
+        )
+
+        print("release_memory_occupation start")
+        t = time.time()
+        engine.release_memory_occupation()
+        if _DEBUG_EXTRA:
+            print("release_memory_occupation", time.time() - t)
+
+        if _DEBUG_EXTRA:
+            time.sleep(5)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            True,
+            "Should be able to allocate big tensors after releasing",
+        )
+
+        if _DEBUG_EXTRA:
+            time.sleep(5)
+
+        print("resume_memory_occupation start")
+        t = time.time()
+        engine.resume_memory_occupation()
+        if _DEBUG_EXTRA:
+            print("resume_memory_occupation", time.time() - t)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            False,
+            "Should not be able to allocate big tensors after resuming",
+        )
+
+        print("update_weights_from_tensor")
+        # As if: PPO has updated hf model's weights, and now we sync it to SGLang
+        engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
+
+        print("generate (#2)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output)
+
+        if _DEBUG_EXTRA:
+            time.sleep(4)
+
+        engine.shutdown()
+
+
+def _try_allocate_big_tensor(size: int = 20_000_000_000):
+    try:
+        torch.empty((size,), dtype=torch.uint8, device="cuda")
+        torch.cuda.empty_cache()
+        return True
+    except torch.cuda.OutOfMemoryError:
+        return False
+
+
+if __name__ == "__main__":
+    unittest.main()