Add e2e test for multi-instance multi-stage memory occupation release/resume (#7208)
Signed-off-by: Ata Fatahi <immrata@gmail.com>
@@ -160,6 +160,7 @@ suites = {
     "per-commit-4-gpu": [
         TestFile("test_local_attn.py", 250),
         TestFile("test_pp_single_node.py", 150),
+        TestFile("test_multi_instance_release_memory_occupation.py", 64),
     ],
     "per-commit-4-gpu-amd": [
         TestFile("test_pp_single_node.py", 150),

test/srt/test_multi_instance_release_memory_occupation.py (new file, 250 lines)
@@ -0,0 +1,250 @@
import multiprocessing
import os
import subprocess
import traceback
import unittest
from multiprocessing import Process
from typing import Iterable, Tuple

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.engine import Engine as SglangEngine
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
    CustomTestCase,
    find_available_port,
)

TEST_SUITE = dict(
    model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    mem_fraction_static=0.85,
    dp_size=2,
    tp_size=2,
)
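# Topology implied by the suite config: dp_size=2 replicas x tp_size=2 ranks
# each = 4 GPUs total. Each dp replica runs its own engine instance spanning
# tp_size GPUs, which is what makes this a multi-instance release/resume test.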
class EngineWrapper:
    """
    A wrapper around the Sglang engine to mock multi-instance cases such as
    RL training.
    """

    def __init__(
        self, model_path, random_seed, mem_fraction_static, device_mesh_cpu, base_gpu_id
    ):
        self._device_mesh_cpu = device_mesh_cpu
        self._tp_rank = device_mesh_cpu.get_local_rank()
        self._rank = device_mesh_cpu.get_rank()
        self._tp_size = device_mesh_cpu.size()
        tp_size_per_node = self._tp_size
        node_rank = self._tp_rank // tp_size_per_node
        first_rank_in_node = self._tp_rank % tp_size_per_node == 0
        engine_kwargs = dict(
            model_path=model_path,
            random_seed=random_seed,
            mem_fraction_static=mem_fraction_static,
            base_gpu_id=base_gpu_id,
            enable_memory_saver=True,
            tp_size=self._tp_size,
            node_rank=node_rank,
            nnodes=1,
        )
        self._engine = None
        if first_rank_in_node:
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = SglangEngine(**engine_kwargs)

        dist.barrier(group=self._device_mesh_cpu.get_group())

    def update_weights_from_tensor(
        self, named_tensors: Iterable[Tuple[str, torch.Tensor]]
    ):
        if self._tp_rank == 0:
            self._engine.update_weights_from_tensor(list(named_tensors))
            self._engine.flush_cache()
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def release_memory_occupation(self, tags):
        if self._tp_rank == 0:
            self._engine.release_memory_occupation(tags)
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def resume_memory_occupation(self, tags):
        if self._tp_rank == 0:
            self._engine.resume_memory_occupation(tags)
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def shutdown(self):
        if self._tp_rank == 0:
            self._engine.shutdown()
        dist.barrier(group=self._device_mesh_cpu.get_group())
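# Only the first rank of each TP group constructs an Engine; the engine then
# drives the rest of its TP workers itself. The other ranks of the group hold
# no engine object and simply meet at the barrier, so every public method
# above keeps the whole group in lockstep.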
def get_gpu_memory_gb(gpu_id=0):
    # Device-wide usage (as reported by nvidia-smi/amd-smi), so allocations
    # made by the engine's subprocesses on this GPU are counted as well.
    return torch.cuda.device_memory_used(gpu_id) / 1024**3
class TestMultiInstanceReleaseMemoryOccupation(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        # CUDA cannot be re-initialized in forked children, so use spawn.
        multiprocessing.set_start_method("spawn")

    def test_multi_instance_release_memory_occupation(self):
        master_port = find_available_port(23456)

        dp_size = TEST_SUITE["dp_size"]
        tp_size = TEST_SUITE["tp_size"]
        world_size = dp_size * tp_size
        processes = []
        output_reader, output_writer = multiprocessing.Pipe(duplex=False)
        for rank in range(world_size):
            p = Process(
                target=_run_sglang_subprocess,
                kwargs=dict(
                    rank=rank,
                    dp_size=dp_size,
                    tp_size=tp_size,
                    model_path=TEST_SUITE["model_path"],
                    master_port=master_port,
                    output_writer=output_writer,
                    mem_fraction_static=TEST_SUITE["mem_fraction_static"],
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(world_size):
            self.assertTrue(
                output_reader.recv(), "Subprocess failed. Check the logs above."
            )
        for p in processes:
            p.join()
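# Each subprocess walks one rank through the full release/resume lifecycle:
# release the KV cache, release the weights, load a fresh HF copy of the
# model, resume and overwrite the engine weights from that copy, drop the HF
# copy, and finally resume the KV cache, asserting the expected memory
# movement at every stage.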
def _run_sglang_subprocess(
    rank: int,
    dp_size: int,
    tp_size: int,
    model_path: str,
    master_port: int,
    output_writer,
    mem_fraction_static: float,
):
    engine = None
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        dist.init_process_group(
            rank=rank,
            device_id=torch.device(f"cuda:{rank}"),
            world_size=dp_size * tp_size,
        )
        torch.cuda.set_device(rank)

        base_gpu_id = rank // tp_size * tp_size
        mesh_kwargs = dict(
            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
        )
        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
        print(
            f"subprocess[{rank=},{base_gpu_id=},{tp_size=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
        )

        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before starting Engine: {_mem_usage}")

        engine = EngineWrapper(
            model_path=model_path,
            random_seed=42,
            mem_fraction_static=mem_fraction_static,
            device_mesh_cpu=inference_device_mesh_cpu["tp"],
            base_gpu_id=base_gpu_id,
        )
        print(f"subprocess[{rank=}] {engine=}", flush=True)
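        # inference_device_mesh_cpu["tp"] slices out this rank's TP subgroup,
        # so ranks {0,1} and {2,3} form two independent engine instances with
        # base_gpu_id 0 and 2 respectively.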
        # 1 - release kv cache
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before releasing Sgl KV cache: {_mem_usage}")
        engine.release_memory_occupation(tags=["kv_cache"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing KV cache must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
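        # With enable_memory_saver=True, releasing a tag is expected to unmap
        # the physical pages backing those allocations while keeping their
        # virtual addresses alive, which is why the device-wide usage above
        # actually drops rather than just moving inside the caching allocator.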
        # 2 - release sglang weights
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before releasing Sgl weights: {_mem_usage}")
        engine.release_memory_occupation(tags=["weights"])

        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing weights must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
        # 3 - load hf model
        _mem_usage = get_gpu_memory_gb(rank)
        print(
            f"GPU{rank} Memory usage after releasing Sgl weights and kv cache: {_mem_usage}"
        )
        hf_model = AutoModelForCausalLM.from_pretrained(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
            torch_dtype="bfloat16",
            device_map=f"cuda:{rank}",
            trust_remote_code=True,
        ).cuda()
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage > _mem_usage
        ), f"Memory usage after loading hf model must be increased! before: {_mem_usage} vs after: {_curr_usage}"
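        # Every rank loads a full, unsharded HF copy of the base model; it
        # stands in for the trainer-side weights that an RL framework would
        # push into the inference engine.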
        # 4 - resume sglang weights and update the weights
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after loading hf model: {_mem_usage}")
        engine.resume_memory_occupation(tags=["weights"])
        engine.update_weights_from_tensor(
            named_tensors=list(hf_model.named_parameters())
        )
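        # The resume must come first: update_weights_from_tensor copies into
        # the engine's weight buffers, so their physical memory has to be
        # mapped again before the copy.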
        # 5 - release hf model
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after resuming Sgl weights: {_mem_usage}")
        del hf_model
        torch.cuda.empty_cache()
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing hf model must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
        # 6 - resume sglang kv cache
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after releasing hf model: {_mem_usage}")
        engine.resume_memory_occupation(tags=["kv_cache"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage > _mem_usage
        ), f"Memory usage after resuming kv cache must be increased! before: {_mem_usage} vs after: {_curr_usage}"
        # 7 - final check
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after resuming Sgl KV cache: {_mem_usage}")

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()

    if engine:
        engine.shutdown()
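# Run directly (4 GPUs required for the dp_size=2 x tp_size=2 layout):
#   python3 test/srt/test_multi_instance_release_memory_occupation.py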
if __name__ == "__main__":
    unittest.main()