# File: sglang/test/srt/test_multi_instance_release_memory_occupation.py
import multiprocessing
import os
import subprocess
import traceback
import unittest
from multiprocessing import Process
from typing import Iterable, Tuple
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.engine import Engine as SglangEngine
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
CustomTestCase,
find_available_port,
)
# Shared configuration for the multi-instance release/resume-memory test:
# which model to load, the static memory fraction for the engine, and the
# data-parallel / tensor-parallel sizes of the mocked deployment.
TEST_SUITE = {
    "model_path": DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    "mem_fraction_static": 0.85,
    "dp_size": 2,
    "tp_size": 2,
}
class EngineWrapper:
    """
    A wrapper around the Sglang engine to mock multi instance cases such as RL training.

    Only the first local rank of the (single, mocked) node constructs the real
    SGLang engine; all other ranks keep ``self._engine is None`` and merely
    synchronize with the engine-owning rank via barriers on the CPU device mesh.
    Every public method therefore gates the engine call on ``self._tp_rank == 0``
    and ends with a collective barrier so all ranks observe the same state.
    """

    def __init__(
        self, model_path, random_seed, mem_fraction_static, device_mesh_cpu, base_gpu_id
    ):
        # CPU mesh is used purely for cross-rank synchronization (barriers).
        self._device_mesh_cpu = device_mesh_cpu
        self._tp_rank = device_mesh_cpu.get_local_rank()
        self._rank = device_mesh_cpu.get_rank()
        self._tp_size = device_mesh_cpu.size()
        # Single-node mock: the whole TP group lives on one node, so node_rank
        # is always 0 and only local rank 0 qualifies as first_rank_in_node.
        tp_size_per_node = self._tp_size
        node_rank = self._tp_rank // tp_size_per_node
        first_rank_in_node = self._tp_rank % tp_size_per_node == 0
        engine_kwargs = dict(
            model_path=model_path,
            random_seed=random_seed,
            mem_fraction_static=mem_fraction_static,
            base_gpu_id=base_gpu_id,
            enable_memory_saver=True,  # required for release/resume_memory_occupation
            tp_size=self._tp_size,
            node_rank=node_rank,
            nnodes=1,
        )
        # Engine exists only on the first rank in the node; others hold None.
        self._engine = None
        if first_rank_in_node:
            # Allow non-zero-rank engine child processes to proceed instead of blocking.
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = SglangEngine(**engine_kwargs)
        # All ranks wait here until the engine is fully constructed.
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def update_weights_from_tensor(
        self, named_tensors: Iterable[Tuple[str, torch.Tensor]]
    ):
        """Push ``named_tensors`` into the engine (rank 0 only), flush its
        KV/prefix cache so stale activations are not reused, then sync all ranks."""
        if self._tp_rank == 0:
            self._engine.update_weights_from_tensor(list(named_tensors))
            self._engine.flush_cache()
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def release_memory_occupation(self, tags):
        """Release the GPU memory groups named by ``tags`` (e.g. "kv_cache",
        "weights") on the engine-owning rank, then barrier."""
        if self._tp_rank == 0:
            self._engine.release_memory_occupation(tags)
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def resume_memory_occupation(self, tags):
        """Re-allocate the previously released memory groups named by ``tags``,
        then barrier."""
        if self._tp_rank == 0:
            self._engine.resume_memory_occupation(tags)
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def shutdown(self):
        """Shut down the engine on the owning rank, then barrier so all ranks exit together."""
        if self._tp_rank == 0:
            self._engine.shutdown()
        dist.barrier(group=self._device_mesh_cpu.get_group())
def get_gpu_memory_gb(gpu_id=0):
    """Return the device memory currently in use on GPU ``gpu_id``, in GiB.

    Bug fix: the original accepted ``gpu_id`` but ignored it, always querying
    the current device even though every call site passes a specific rank.
    The device is now forwarded to ``torch.cuda.device_memory_used``.

    Args:
        gpu_id: CUDA device index to query (defaults to 0).

    Returns:
        float: used device memory in GiB.
    """
    return torch.cuda.device_memory_used(gpu_id) / 1024**3
class TestMultiInstanceReleaseMemoryOccupation(CustomTestCase):
    """Launches one subprocess per rank and asserts every rank reports success."""

    @classmethod
    def setUpClass(cls):
        # "spawn" is required: a forked child cannot re-initialize CUDA.
        multiprocessing.set_start_method("spawn")

    def test_multi_instance_release_memory_occupation(self):
        """Run dp_size * tp_size workers; each sends one success boolean back."""
        master_port = find_available_port(23456)
        dp_size = TEST_SUITE["dp_size"]
        tp_size = TEST_SUITE["tp_size"]
        world_size = dp_size * tp_size
        processes = []
        # One-way pipe shared by all children: each sends exactly one bool.
        output_reader, output_writer = multiprocessing.Pipe(duplex=False)
        for rank in range(world_size):
            p = Process(
                target=_run_sglang_subprocess,
                kwargs=dict(
                    rank=rank,
                    dp_size=dp_size,
                    tp_size=tp_size,
                    model_path=TEST_SUITE["model_path"],
                    master_port=master_port,
                    output_writer=output_writer,
                    mem_fraction_static=TEST_SUITE["mem_fraction_static"],
                ),
            )
            p.start()
            processes.append(p)
        # Receive one result per rank BEFORE joining, since a child only
        # exits after sending; a failed rank sends False rather than hanging.
        for _ in range(world_size):
            self.assertTrue(
                output_reader.recv(), f"Subprocess fail. Check the logs above."
            )
        for p in processes:
            p.join()
def _run_sglang_subprocess(
    rank: int,
    dp_size: int,
    tp_size: int,
    model_path: str,
    master_port: int,
    output_writer,
    mem_fraction_static: float,
):
    """Worker body for one rank of the multi-instance memory-occupation test.

    Sequence: init torch.distributed → build dp/tp/pp device meshes → start an
    ``EngineWrapper`` → release KV cache, then weights → load an HF model →
    resume weights and update them from the HF tensors → free the HF model →
    resume KV cache. After each step it asserts that GPU memory usage moved in
    the expected direction. Sends True/False through ``output_writer`` so the
    parent test can assert success.

    NOTE(review): memory-direction assertions assume this test has exclusive
    use of the GPUs — confirm no other process shares them.
    """
    engine = None
    try:
        # Rendezvous for torch.distributed; one process per rank/GPU.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        dist.init_process_group(
            rank=rank,
            device_id=torch.device(f"cuda:{rank}"),
            world_size=dp_size * tp_size,
        )
        torch.cuda.set_device(rank)
        # First GPU of this rank's TP group (ranks laid out dp-major).
        base_gpu_id = rank // tp_size * tp_size
        mesh_kwargs = dict(
            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
        )
        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
        print(
            f"subprocess[{rank=},{base_gpu_id=},{rank=},{tp_size=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
        )
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before starting Engine: {_mem_usage}")
        # The "tp" submesh gives each dp replica its own TP synchronization group.
        engine = EngineWrapper(
            model_path=model_path,
            random_seed=42,
            mem_fraction_static=mem_fraction_static,
            device_mesh_cpu=inference_device_mesh_cpu["tp"],
            base_gpu_id=base_gpu_id,
        )
        print(f"subprocess[{rank=}] {engine=}", flush=True)
        # 1 - release kv cache: usage must drop.
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before releasing Sgl KV cache: {_mem_usage}")
        engine.release_memory_occupation(tags=["kv_cache"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing KV cache must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
        # 2 - release sglang weights: usage must drop further.
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before releasing Sgl weights: {_mem_usage}")
        engine.release_memory_occupation(tags=["weights"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing weights must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
        # 3 - load hf model into the freed space: usage must rise.
        _mem_usage = get_gpu_memory_gb(rank)
        print(
            f"GPU{rank} Memory usage after releasing Sgl weights and kv cache: {_mem_usage}"
        )
        # NOTE(review): the trailing .cuda() looks redundant given
        # device_map=f"cuda:{rank}" already places the model — confirm.
        hf_model = AutoModelForCausalLM.from_pretrained(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
            torch_dtype="bfloat16",
            device_map=f"cuda:{rank}",
            trust_remote_code=True,
        ).cuda()
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage > _mem_usage
        ), f"Memory usage after loading hf model must be increased! before: {_mem_usage} vs after: {_curr_usage}"
        # 4 - resume sglang weights and update them from the hf tensors.
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after loading hf model: {_mem_usage}")
        engine.resume_memory_occupation(tags=["weights"])
        engine.update_weights_from_tensor(
            named_tensors=list(hf_model.named_parameters())
        )
        # 5 - release hf model: usage must drop.
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after resuming Sgl weights: {_mem_usage}")
        del hf_model
        torch.cuda.empty_cache()
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing hf model must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
        # 6 - resume sglang kv cache: usage must rise again.
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after releasing hf model: {_mem_usage}")
        engine.resume_memory_occupation(tags=["kv_cache"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage > _mem_usage
        ), f"Memory usage after resuming kv cache must be increased! before: {_mem_usage} vs after: {_curr_usage}"
        # 7 - Final checking!
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after resuming Sgl KV cache: {_mem_usage}")
        execution_ok = True
    except Exception as e:
        # Report failure to the parent instead of hanging its recv() loop.
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False
    # Always send exactly one result so the parent's per-rank recv() completes.
    output_writer.send(execution_ok)
    output_writer.close()
    # NOTE(review): dist.destroy_process_group() is never called — confirm
    # process exit is relied on for cleanup.
    if engine:
        engine.shutdown()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()