diff --git a/python/pyproject.toml b/python/pyproject.toml index 7ad60e4b2..ff3e4486f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -torch_memory_saver = ["torch_memory_saver>=0.0.4"] +torch_memory_saver = ["torch_memory_saver>=0.0.8"] decord = ["decord"] test = [ "accelerate", diff --git a/python/sglang/srt/constants.py b/python/sglang/srt/constants.py new file mode 100644 index 000000000..aa03a089b --- /dev/null +++ b/python/sglang/srt/constants.py @@ -0,0 +1,3 @@ +# GPU Memory Types +GPU_MEMORY_TYPE_KV_CACHE = "kv_cache" +GPU_MEMORY_TYPE_WEIGHTS = "weights" diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 76b06c8ba..12781784e 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -31,6 +31,7 @@ import numpy as np import torch from torch.distributed import ProcessGroup +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll from sglang.srt.disaggregation.utils import ( FAKE_BOOTSTRAP_HOST, @@ -90,7 +91,7 @@ class DecodeReqToTokenPool: self.max_context_len = max_context_len self.device = device self.pre_alloc_size = pre_alloc_size - with memory_saver_adapter.region(): + with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE): self.req_to_token = torch.zeros( (size + pre_alloc_size, max_context_len), dtype=torch.int32, diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 34f11d4d5..4c88a5289 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -479,17 +479,15 @@ class Engine(EngineBase): self.tokenizer_manager.get_weights_by_name(obj, None) ) - def 
release_memory_occupation(self): - """Release GPU occupation temporarily.""" - obj = ReleaseMemoryOccupationReqInput() + def release_memory_occupation(self, tags: Optional[List[str]] = None): + obj = ReleaseMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() return loop.run_until_complete( self.tokenizer_manager.release_memory_occupation(obj, None) ) - def resume_memory_occupation(self): - """Resume GPU occupation.""" - obj = ResumeMemoryOccupationReqInput() + def resume_memory_occupation(self, tags: Optional[List[str]] = None): + obj = ResumeMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() return loop.run_until_complete( self.tokenizer_manager.resume_memory_occupation(obj, None) @@ -670,11 +668,9 @@ def _launch_subprocesses( scheduler_procs = [] if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes memory_saver_adapter = TorchMemorySaverAdapter.create( enable=server_args.enable_memory_saver ) - scheduler_pipe_readers = [] nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) @@ -710,6 +706,7 @@ def _launch_subprocesses( writer, ), ) + with memory_saver_adapter.configure_subprocess(): proc.start() scheduler_procs.append(proc) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c94d81eb3..44fd9fa8e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput: @dataclass class ReleaseMemoryOccupationReqInput: - pass + # Optional tags to identify the memory region, which is primarily used for RL + # Currently we only support `weights` and `kv_cache` + tags: Optional[List[str]] = None @dataclass @@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput: @dataclass class ResumeMemoryOccupationReqInput: - pass + # Optional tags to identify the memory region, which is primarily used for RL + # Currently we only support `weights` and `kv_cache` + tags: 
Optional[List[str]] = None @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2b0dbae5f..eae5a61fc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -36,6 +36,7 @@ from torch.distributed import barrier from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.constrained.base_grammar_backend import ( INVALID_GRAMMAR_OBJ, create_grammar_backend, @@ -450,8 +451,6 @@ class Scheduler( t = threading.Thread(target=self.watchdog_thread, daemon=True) t.start() self.parent_process = psutil.Process().parent() - - # Init memory saver self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=server_args.enable_memory_saver ) @@ -2227,23 +2226,40 @@ class Scheduler( return GetWeightsByNameReqOutput(parameter) def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput): - self.memory_saver_adapter.check_validity( - caller_name="release_memory_occupation" - ) - self.stashed_model_static_state = _export_static_state( - self.tp_worker.worker.model_runner.model - ) - self.memory_saver_adapter.pause() - self.flush_cache() + tags = recv_req.tags + import subprocess + + if tags is None: + tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE] + + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.worker.model_runner.model + ) + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS) + return ReleaseMemoryOccupationReqOutput() def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput): - self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation") - self.memory_saver_adapter.resume() - 
_import_static_state( - self.tp_worker.worker.model_runner.model, self.stashed_model_static_state - ) - del self.stashed_model_static_state + tags = recv_req.tags + if tags is None or len(tags) == 0: + tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE] + + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) + _import_static_state( + self.tp_worker.worker.model_runner.model, + self.stashed_model_static_state, + ) + del self.stashed_model_static_state + + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + return ResumeMemoryOccupationReqOutput() def slow_down(self, recv_req: SlowDownReqInput): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index d426093df..c01807f1b 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -35,6 +35,7 @@ import torch import triton import triton.language as tl +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2 @@ -54,6 +55,7 @@ class ReqToTokenPool: device: str, enable_memory_saver: bool, ): + memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver ) @@ -61,7 +63,7 @@ class ReqToTokenPool: self.size = size self.max_context_len = max_context_len self.device = device - with memory_saver_adapter.region(): + with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): self.req_to_token = torch.zeros( (size, max_context_len), dtype=torch.int32, device=device ) @@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache): ) def _create_buffers(self): - with self.memory_saver_adapter.region(): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.enable_custom_mem_pool @@ -610,7 +612,7 @@ class 
MLATokenToKVPool(KVCache): else: self.custom_mem_pool = None - with self.memory_saver_adapter.region(): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.custom_mem_pool @@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache): end_layer, ) - with self.memory_saver_adapter.region(): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): # [size, head_num, head_dim] for each layer self.k_buffer = [ torch.zeros( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6ffb0aed1..e1335074a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -30,6 +30,7 @@ from sglang.srt import debug_utils from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.distributed import ( get_tp_group, get_world_group, @@ -222,6 +223,7 @@ class ModelRunner: def initialize(self, min_per_gpu_memory: float): server_args = self.server_args + self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=self.server_args.enable_memory_saver ) @@ -547,7 +549,7 @@ class ModelRunner: monkey_patch_vllm_parallel_state() monkey_patch_isinstance_for_vllm_base_layer() - with self.memory_saver_adapter.region(): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS): self.model = get_model( model_config=self.model_config, load_config=self.load_config, diff --git a/python/sglang/srt/torch_memory_saver_adapter.py b/python/sglang/srt/torch_memory_saver_adapter.py index 2b1080d25..a46151782 100644 --- a/python/sglang/srt/torch_memory_saver_adapter.py +++ b/python/sglang/srt/torch_memory_saver_adapter.py @@ -1,11 +1,13 @@ import logging +import threading +import time from abc 
import ABC -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext try: import torch_memory_saver - _primary_memory_saver = torch_memory_saver.TorchMemorySaver() + _memory_saver = torch_memory_saver.torch_memory_saver import_error = None except ImportError as e: import_error = e @@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC): def configure_subprocess(self): raise NotImplementedError - def region(self): + def region(self, tag: str): raise NotImplementedError - def pause(self): + def pause(self, tag: str): raise NotImplementedError - def resume(self): + def resume(self, tag: str): raise NotImplementedError @property @@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC): class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter): + """Adapter for TorchMemorySaver with tag-based control""" + def configure_subprocess(self): return torch_memory_saver.configure_subprocess() - def region(self): - return _primary_memory_saver.region() + def region(self, tag: str): + return _memory_saver.region(tag=tag) - def pause(self): - return _primary_memory_saver.pause() + def pause(self, tag: str): + return _memory_saver.pause(tag=tag) - def resume(self): - return _primary_memory_saver.resume() + def resume(self, tag: str): + return _memory_saver.resume(tag=tag) @property def enabled(self): - return _primary_memory_saver.enabled + return _memory_saver is not None and _memory_saver.enabled class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): @@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): yield @contextmanager - def region(self): + def region(self, tag: str): yield - def pause(self): + def pause(self, tag: str): pass - def resume(self): + def resume(self, tag: str): pass @property diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9e3011dd8..7c335c79d 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -37,6 +37,7 @@ from 
sglang.utils import get_exception_traceback # General test models DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct" +DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 58142e73b..42e52de4b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -74,7 +74,6 @@ suites = { TestFile("test_radix_attention.py", 105), TestFile("test_reasoning_content.py", 89), TestFile("test_regex_constrained.py", 64), - TestFile("test_release_memory_occupation.py", 44), TestFile("test_request_length_validation.py", 31), TestFile("test_retract_decode.py", 54), TestFile("test_server_args.py", 1), @@ -146,6 +145,7 @@ suites = { TestFile("test_patch_torch.py", 19), TestFile("test_update_weights_from_distributed.py", 103), TestFile("test_verl_engine_2_gpu.py", 64), + TestFile("test_release_memory_occupation.py", 44), ], "per-commit-2-gpu-amd": [ TestFile("models/lora/test_lora_tp.py", 116), diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index 7a7659280..eb20fc46b 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -1,3 +1,32 @@ +"""Test memory release and resume operations for SGLang engine in hybrid RL training. + +This test suite evaluates the SGLang engine's memory management capabilities, focusing +on releasing and resuming memory occupation for KV cache and model weights. It simulates +an RL workflow where the SGLang engine acts as a rollout engine for experience collection. 
+The process involves initializing the engine, sending a small number of requests to simulate +rollout, releasing memory to mimic offloading during RL training, resuming memory occupation, +updating weights with a trained HuggingFace model, and verifying the updated weights. + +Detailed in our proposal (https://github.com/sgl-project/sglang/pull/7099), three test cases +are included: + +1. Basic Release and Resume: Uses a lower mem_fraction_static (0.6) to carefully control memory +allocation and avoid OOM errors. This test simulates a scenario without multi-stage memory management, +ensuring the engine can release and resume memory occupation while maintaining functionality after +weight updates. + +2. Multi-Stage Release and Resume: Employs a higher mem_fraction_static (0.85) to simulate higher +memory pressure, leveraging multi-stage memory management. It sequentially releases and resumes +KV cache and model weights, verifying memory deallocation and reallocation at each stage, and +ensuring correct weight updates and text generation. + +3. Tensor Parallel Tests: Tests memory release and resume operations with different tensor parallel +configurations (tp=1, tp=2) to ensure proper memory management in distributed settings. For different +data parallel sizes, we test them in verl.
+""" + +import gc +import os import time import unittest @@ -5,93 +34,221 @@ import torch from transformers import AutoModelForCausalLM import sglang as sgl -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, + CustomTestCase, +) # (temporarily) set to true to observe memory usage in nvidia-smi more clearly -_DEBUG_EXTRA = True +_DEBUG_EXTRA = False + + +def get_gpu_memory_gb(): + return torch.cuda.device_memory_used() / 1024**3 class TestReleaseMemoryOccupation(CustomTestCase): - def test_release_and_resume_occupation(self): - prompt = "Today is a sunny day and I like" - sampling_params = {"temperature": 0, "max_new_tokens": 8} - model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - expect_output = " to spend it outdoors. I decided to" - + def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1): + """Common setup for engine and HF model.""" engine = sgl.Engine( model_path=model_name, random_seed=42, enable_memory_saver=True, + mem_fraction_static=mem_fraction_static, + tp_size=tp_size, # disable_cuda_graph=True, # for debugging only ) - hf_model_new = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype="bfloat16" - ) + return engine + + def _common_test_params(self): + """Common test parameters.""" + return { + "prompt": "Today is a sunny day and I like", + "sampling_params": {"temperature": 0, "max_new_tokens": 8}, + "expect_output_before_update_weights": " to spend it outdoors. I decided to", + "expect_output_after_update_weights": " to go for a walk. 
I like", + } + + def _test_initial_generation( + self, engine, prompt, sampling_params, expect_output_before_update_weights + ): + """Test initial generation and memory allocation.""" print("generate (#1)") outputs = engine.generate(prompt, sampling_params)["text"] - self.assertEqual(outputs, expect_output) + self.assertEqual(outputs, expect_output_before_update_weights) if _DEBUG_EXTRA: time.sleep(3) - self.assertEqual( - _try_allocate_big_tensor(), - False, - "Should not be able to allocate big tensors before releasing", - ) + def test_release_and_resume_occupation(self): + # Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM + model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + assert ( + torch.cuda.device_count() >= 2 + ), "Need at least 2 GPUs for tensor parallel tests" - print("release_memory_occupation start") - t = time.perf_counter() - engine.release_memory_occupation() - if _DEBUG_EXTRA: - print("release_memory_occupation", time.perf_counter() - t) + for tp_size in [1, 2]: - if _DEBUG_EXTRA: - time.sleep(5) + print(f"Testing tp_size={tp_size} for test_release_and_resume_occupation") + engine = self._setup_engine( + model_name=model_name, mem_fraction_static=0.6, tp_size=tp_size + ) + params = self._common_test_params() - self.assertEqual( - _try_allocate_big_tensor(), - True, - "Should be able to allocate big tensors aftre releasing", - ) + self._test_initial_generation( + engine, + params["prompt"], + params["sampling_params"], + params["expect_output_before_update_weights"], + ) - if _DEBUG_EXTRA: - time.sleep(5) + t = time.perf_counter() + gpu_memory_usage_before_release = get_gpu_memory_gb() + engine.release_memory_occupation() + gpu_memory_usage_after_release = get_gpu_memory_gb() - print("resume_memory_occupation start") - t = time.perf_counter() - engine.resume_memory_occupation() - if _DEBUG_EXTRA: - print("resume_memory_occupation", time.perf_counter() - t) + self.assertLess( + 
gpu_memory_usage_after_release, + gpu_memory_usage_before_release, + ) - self.assertEqual( - _try_allocate_big_tensor(), - False, - "Should not be able to allocate big tensors after resuming", - ) + print( + f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB" + ) - print("update_weights_from_tensor") - # As if: PPO has updated hf model's weights, and now we sync it to SGLang - engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) + if _DEBUG_EXTRA: + time.sleep(3) - print("generate (#2)") - outputs = engine.generate(prompt, sampling_params)["text"] - self.assertEqual(outputs, expect_output) + t = time.perf_counter() + engine.resume_memory_occupation() + print( + f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB" + ) - if _DEBUG_EXTRA: - time.sleep(4) + hf_model_new = AutoModelForCausalLM.from_pretrained( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, + torch_dtype="bfloat16", + device_map="cuda", + ) + engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) - engine.shutdown() + # destroy the hf model + del hf_model_new + torch.cuda.empty_cache() + print("generate (#2)") + outputs = engine.generate(params["prompt"], params["sampling_params"])[ + "text" + ] + self.assertEqual(outputs, params["expect_output_after_update_weights"]) + engine.shutdown() -def _try_allocate_big_tensor(size: int = 20_000_000_000): - try: - torch.empty((size,), dtype=torch.uint8, device="cuda") - torch.cuda.empty_cache() - return True - except torch.cuda.OutOfMemoryError: - return False + def test_multi_stage_release_and_resume(self): + # With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM + model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + for tp_size in [1, 2]: + if tp_size == 2 and torch.cuda.device_count() < 2: + continue + + print(f"Testing tp_size={tp_size} for 
test_multi_stage_release_and_resume") + engine = sgl.Engine( + model_path=model_name, + random_seed=42, + enable_memory_saver=True, + mem_fraction_static=0.85, # Higher memory pressure + tp_size=tp_size, + ) + params = self._common_test_params() + + self._test_initial_generation( + engine, + params["prompt"], + params["sampling_params"], + params["expect_output_before_update_weights"], + ) + + t = time.perf_counter() + gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb() + engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE]) + + gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb() + + self.assertLess( + gpu_memory_usage_after_release_kv_cache, + gpu_memory_usage_before_release_kv_cache, + ) + engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + + gpu_memory_usage_after_release_weights = get_gpu_memory_gb() + + self.assertLess( + gpu_memory_usage_after_release_weights, + gpu_memory_usage_after_release_kv_cache, + ) + + print(f"Release took {time.perf_counter() - t:.2f}s") + print( + f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f} → {gpu_memory_usage_after_release_kv_cache:.1f} → {gpu_memory_usage_after_release_weights:.1f} GB" + ) + + if _DEBUG_EXTRA: + time.sleep(3) + + t = time.perf_counter() + gpu_memory_usage_before_resume_weights = get_gpu_memory_gb() + + # gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close + + self.assertAlmostEqual( + gpu_memory_usage_after_release_weights, + gpu_memory_usage_before_resume_weights, + delta=3.0, + ) + print(f"Resume weights took {time.perf_counter() - t:.2f}s") + + engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + gpu_memory_usage_after_resume_weights = get_gpu_memory_gb() + + self.assertGreater( + gpu_memory_usage_after_resume_weights, + gpu_memory_usage_before_resume_weights, + ) + + # Update weights from a trained model to serving engine, and then destroy the trained model + hf_model_new = 
AutoModelForCausalLM.from_pretrained( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, + torch_dtype="bfloat16", + device_map="cuda", + ) + gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb() + engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) + + # destroy the hf model + del hf_model_new + torch.cuda.empty_cache() + engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE]) + + gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb() + self.assertGreater( + gpu_memory_usage_after_resume_kv_cache, + gpu_memory_usage_after_resume_weights, + ) + + print(f"Resume + update took {time.perf_counter() - t:.2f}s") + print( + f"Memory: {gpu_memory_usage_before_resume_weights:.1f} → {gpu_memory_usage_after_resume_weights:.1f} → {gpu_memory_usage_after_loaded_hf_model:.1f} → {gpu_memory_usage_after_resume_kv_cache:.1f} GB" + ) + + print("generate (#2)") + outputs = engine.generate(params["prompt"], params["sampling_params"])[ + "text" + ] + self.assertEqual(outputs, params["expect_output_after_update_weights"]) + engine.shutdown() if __name__ == "__main__": diff --git a/test/srt/test_verl_engine_2_gpu.py b/test/srt/test_verl_engine_2_gpu.py index aaf9ac460..40321ee3f 100644 --- a/test/srt/test_verl_engine_2_gpu.py +++ b/test/srt/test_verl_engine_2_gpu.py @@ -235,7 +235,8 @@ def _run_subprocess( output_writer.send(execution_ok) output_writer.close() - engine.shutdown() + if "engine" in locals() and engine is not None: + engine.shutdown() print(f"subprocess[{rank=}] end", flush=True) diff --git a/test/srt/test_verl_engine_4_gpu.py b/test/srt/test_verl_engine_4_gpu.py index d620f0e44..014f17daf 100644 --- a/test/srt/test_verl_engine_4_gpu.py +++ b/test/srt/test_verl_engine_4_gpu.py @@ -249,7 +249,8 @@ def _run_subprocess( output_writer.send(execution_ok) output_writer.close() - engine.shutdown() + if "engine" in locals() and engine is not None: + engine.shutdown() print(f"subprocess[{rank=}] end", flush=True)