diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index cf81ab7c4..c38a6f4d5 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -160,6 +160,7 @@ suites = {
     "per-commit-4-gpu": [
         TestFile("test_local_attn.py", 250),
         TestFile("test_pp_single_node.py", 150),
+        TestFile("test_multi_instance_release_memory_occupation.py", 64),
     ],
     "per-commit-4-gpu-amd": [
         TestFile("test_pp_single_node.py", 150),
diff --git a/test/srt/test_multi_instance_release_memory_occupation.py b/test/srt/test_multi_instance_release_memory_occupation.py
new file mode 100644
index 000000000..e4e8d9081
--- /dev/null
+++ b/test/srt/test_multi_instance_release_memory_occupation.py
@@ -0,0 +1,250 @@
+import multiprocessing
+import os
+import traceback
+import unittest
+from multiprocessing import Process
+from typing import Iterable, Tuple
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import init_device_mesh
+from transformers import AutoModelForCausalLM
+
+from sglang.srt.entrypoints.engine import Engine as SglangEngine
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
+    CustomTestCase,
+    find_available_port,
+)
+
+TEST_SUITE = dict(
+    model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    mem_fraction_static=0.85,
+    dp_size=2,
+    tp_size=2,
+)
+
+
+class EngineWrapper:
+    """
+    A wrapper around the SGLang engine that mocks multi-instance setups such
+    as RL training: only the first rank on each node owns an engine, and all
+    ranks synchronize on a barrier after every engine call.
+    """
+
+    def __init__(
+        self, model_path, random_seed, mem_fraction_static, device_mesh_cpu, base_gpu_id
+    ):
+        self._device_mesh_cpu = device_mesh_cpu
+        self._tp_rank = device_mesh_cpu.get_local_rank()
+        self._rank = device_mesh_cpu.get_rank()
+        self._tp_size = device_mesh_cpu.size()
+        # Single-node setup: every TP rank lives on this node.
+        tp_size_per_node = self._tp_size
+        node_rank = self._tp_rank // tp_size_per_node
+        first_rank_in_node = self._tp_rank % tp_size_per_node == 0
+        engine_kwargs = dict(
+            model_path=model_path,
+            random_seed=random_seed,
+            mem_fraction_static=mem_fraction_static,
+            base_gpu_id=base_gpu_id,
+            enable_memory_saver=True,
+            tp_size=self._tp_size,
+            node_rank=node_rank,
+            nnodes=1,
+        )
+        self._engine = None
+        if first_rank_in_node:
+            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
+            self._engine = SglangEngine(**engine_kwargs)
+
+        dist.barrier(group=self._device_mesh_cpu.get_group())
+
+    def update_weights_from_tensor(
+        self, named_tensors: Iterable[Tuple[str, torch.Tensor]]
+    ):
+        if self._tp_rank == 0:
+            self._engine.update_weights_from_tensor(list(named_tensors))
+            self._engine.flush_cache()
+        dist.barrier(group=self._device_mesh_cpu.get_group())
+
+    def release_memory_occupation(self, tags):
+        if self._tp_rank == 0:
+            self._engine.release_memory_occupation(tags)
+        dist.barrier(group=self._device_mesh_cpu.get_group())
+
+    def resume_memory_occupation(self, tags):
+        if self._tp_rank == 0:
+            self._engine.resume_memory_occupation(tags)
+        dist.barrier(group=self._device_mesh_cpu.get_group())
+
+    def shutdown(self):
+        if self._tp_rank == 0:
+            self._engine.shutdown()
+        dist.barrier(group=self._device_mesh_cpu.get_group())
+
+
+def get_gpu_memory_gb(gpu_id=0):
+    # Device-wide memory usage as reported by the driver, in GiB.
+    return torch.cuda.device_memory_used(gpu_id) / 1024**3
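+
+
+# The subprocess body below (_run_sglang_subprocess) walks one engine per DP
+# replica through a full release/resume cycle: drop the KV cache, drop the
+# weights, load an HF model into the freed memory, resume and overwrite the
+# weights from the HF tensors, free the HF model, and resume the KV cache.
+# Each step asserts that device-wide memory usage moved in the expected
+# direction.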
+
+
+class TestMultiInstanceReleaseMemoryOccupation(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        multiprocessing.set_start_method("spawn")
+
+    def test_multi_instance_release_memory_occupation(self):
+        master_port = find_available_port(23456)
+
+        dp_size = TEST_SUITE["dp_size"]
+        tp_size = TEST_SUITE["tp_size"]
+        world_size = dp_size * tp_size
+        processes = []
+        output_reader, output_writer = multiprocessing.Pipe(duplex=False)
+        for rank in range(world_size):
+            p = Process(
+                target=_run_sglang_subprocess,
+                kwargs=dict(
+                    rank=rank,
+                    dp_size=dp_size,
+                    tp_size=tp_size,
+                    model_path=TEST_SUITE["model_path"],
+                    master_port=master_port,
+                    output_writer=output_writer,
+                    mem_fraction_static=TEST_SUITE["mem_fraction_static"],
+                ),
+            )
+            p.start()
+            processes.append(p)
+
+        for _ in range(world_size):
+            self.assertTrue(
+                output_reader.recv(), "Subprocess failed. Check the logs above."
+            )
+        for p in processes:
+            p.join()
+
+
+def _run_sglang_subprocess(
+    rank: int,
+    dp_size: int,
+    tp_size: int,
+    model_path: str,
+    master_port: int,
+    output_writer,
+    mem_fraction_static: float,
+):
+    engine = None
+    try:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(master_port)
+        dist.init_process_group(
+            rank=rank,
+            device_id=torch.device(f"cuda:{rank}"),
+            world_size=dp_size * tp_size,
+        )
+        torch.cuda.set_device(rank)
+
+        # First GPU of the TP group this rank belongs to.
+        base_gpu_id = rank // tp_size * tp_size
+        mesh_kwargs = dict(
+            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
+        )
+        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
+        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
+        print(
+            f"subprocess[{rank=},{base_gpu_id=},{tp_size=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
+        )
+
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage before starting Engine: {_mem_usage}")
+
+        engine = EngineWrapper(
+            model_path=model_path,
+            random_seed=42,
+            mem_fraction_static=mem_fraction_static,
+            device_mesh_cpu=inference_device_mesh_cpu["tp"],
+            base_gpu_id=base_gpu_id,
+        )
+        print(f"subprocess[{rank=}] {engine=}", flush=True)
+
+        # 1 - release kv cache
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage before releasing Sgl KV cache: {_mem_usage}")
+        engine.release_memory_occupation(tags=["kv_cache"])
+        _curr_usage = get_gpu_memory_gb(rank)
+        assert (
+            _curr_usage < _mem_usage
+        ), f"Memory usage after releasing KV cache must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
+
+        # 2 - release sglang weights
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage before releasing Sgl weights: {_mem_usage}")
+        engine.release_memory_occupation(tags=["weights"])
+        _curr_usage = get_gpu_memory_gb(rank)
+        assert (
+            _curr_usage < _mem_usage
+        ), f"Memory usage after releasing weights must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
+
+        # 3 - load hf model
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(
+            f"GPU{rank} Memory usage after releasing Sgl weights and kv cache: {_mem_usage}"
+        )
+        hf_model = AutoModelForCausalLM.from_pretrained(
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
+            torch_dtype="bfloat16",
+            device_map=f"cuda:{rank}",
+            trust_remote_code=True,
+        )
+        _curr_usage = get_gpu_memory_gb(rank)
+        assert (
+            _curr_usage > _mem_usage
+        ), f"Memory usage after loading hf model must be increased! before: {_mem_usage} vs after: {_curr_usage}"
+
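+        # Weights must be resumed before update_weights_from_tensor below:
+        # release_memory_occupation(tags=["weights"]) freed the parameter
+        # storage, and resume_memory_occupation re-materializes it so the
+        # incoming tensors have somewhere to land.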
+        # 4 - resume sglang weights and update the weights
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage after loading hf model: {_mem_usage}")
+        engine.resume_memory_occupation(tags=["weights"])
+        engine.update_weights_from_tensor(
+            named_tensors=list(hf_model.named_parameters())
+        )
+
+        # 5 - release hf model
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage after resuming Sgl weights: {_mem_usage}")
+        del hf_model
+        torch.cuda.empty_cache()
+        _curr_usage = get_gpu_memory_gb(rank)
+        assert (
+            _curr_usage < _mem_usage
+        ), f"Memory usage after releasing hf model must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
+
+        # 6 - resume sglang kv cache
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage after releasing hf model: {_mem_usage}")
+        engine.resume_memory_occupation(tags=["kv_cache"])
+        _curr_usage = get_gpu_memory_gb(rank)
+        assert (
+            _curr_usage > _mem_usage
+        ), f"Memory usage after resuming kv cache must be increased! before: {_mem_usage} vs after: {_curr_usage}"
+
+        # 7 - final check
+        _mem_usage = get_gpu_memory_gb(rank)
+        print(f"GPU{rank} Memory usage after resuming Sgl KV cache: {_mem_usage}")
+
+        execution_ok = True
+    except Exception as e:
+        print(f"subprocess[{rank=}] has error: {e}", flush=True)
+        traceback.print_exc()
+        execution_ok = False
+
+    output_writer.send(execution_ok)
+    output_writer.close()
+
+    if engine:
+        engine.shutdown()
+
+
+if __name__ == "__main__":
+    unittest.main()
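For reference, the release/resume cycle exercised above can also be driven from a single engine, without the multi-process scaffolding. A minimal single-GPU sketch using only calls that appear in this test; the construction arguments mirror TEST_SUITE and are illustrative rather than prescriptive:

    import torch
    from transformers import AutoModelForCausalLM

    from sglang.srt.entrypoints.engine import Engine
    from sglang.test.test_utils import (
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
    )

    # enable_memory_saver is what makes release/resume_memory_occupation work.
    engine = Engine(
        model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        mem_fraction_static=0.85,
        enable_memory_saver=True,
    )

    # Free the KV cache, then the weights (steps 1-2 of the test).
    engine.release_memory_occupation(["kv_cache"])
    engine.release_memory_occupation(["weights"])

    # Use the freed memory, e.g. for a trainer-side copy of the model (step 3).
    hf_model = AutoModelForCausalLM.from_pretrained(
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
        torch_dtype="bfloat16",
        device_map="cuda:0",
    )

    # Resume the weights, overwrite them from the trainer, and drop any stale
    # cache entries (step 4), then free the trainer copy and bring the KV
    # cache back before serving again (steps 5-6).
    engine.resume_memory_occupation(["weights"])
    engine.update_weights_from_tensor(list(hf_model.named_parameters()))
    engine.flush_cache()
    del hf_model
    torch.cuda.empty_cache()
    engine.resume_memory_occupation(["kv_cache"])

    engine.shutdown()

The ordering follows the test: weights are resumed and updated before the KV cache comes back, so no cache entry can outlive the parameters it was computed from.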