251 lines
8.5 KiB
Python
251 lines
8.5 KiB
Python
import multiprocessing
|
|
import os
|
|
import subprocess
|
|
import traceback
|
|
import unittest
|
|
from multiprocessing import Process
|
|
from typing import Iterable, Tuple
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed.device_mesh import init_device_mesh
|
|
from transformers import AutoModelForCausalLM
|
|
|
|
from sglang.srt.entrypoints.engine import Engine as SglangEngine
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
|
|
CustomTestCase,
|
|
find_available_port,
|
|
)
|
|
|
|
# Configuration for the multi-instance memory test: dp_size engine instances,
# each tensor-parallel over tp_size GPUs, i.e. dp_size * tp_size worker ranks.
TEST_SUITE = {
    "model_path": DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    "mem_fraction_static": 0.85,
    "dp_size": 2,
    "tp_size": 2,
}
|
|
|
|
|
|
class EngineWrapper:
    """
    A wrapper around an SGLang engine to mock multi-instance cases such as RL training.

    Every rank of a TP group constructs one wrapper, but only local TP rank 0
    (the "leader") creates a real ``SglangEngine``; the other ranks keep
    ``self._engine = None``. Each public method performs the engine call on the
    leader only and then waits on a barrier over the TP group's CPU process
    group, so all ranks of the group leave the method together.
    """

    def __init__(
        self, model_path, random_seed, mem_fraction_static, device_mesh_cpu, base_gpu_id
    ):
        self._device_mesh_cpu = device_mesh_cpu
        self._tp_rank = device_mesh_cpu.get_local_rank()
        # Global rank; stored for debugging, not read elsewhere in this class.
        self._rank = device_mesh_cpu.get_rank()
        self._tp_size = device_mesh_cpu.size()

        # Single-node layout (nnodes=1 below): the whole TP group lives on this
        # node, so the local TP rank is always < tp_size.
        tp_size_per_node = self._tp_size
        node_rank = self._tp_rank // tp_size_per_node

        engine_kwargs = dict(
            model_path=model_path,
            random_seed=random_seed,
            mem_fraction_static=mem_fraction_static,
            base_gpu_id=base_gpu_id,
            enable_memory_saver=True,
            tp_size=self._tp_size,
            node_rank=node_rank,
            nnodes=1,
        )
        self._engine = None
        if self._is_leader:
            # NOTE(review): environment switch read by SGLang itself; its exact
            # semantics are not visible in this file.
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = SglangEngine(**engine_kwargs)

        self._barrier()

    @property
    def _is_leader(self):
        """True on the single rank of the TP group that owns the engine."""
        return self._tp_rank == 0

    def _barrier(self):
        """Block until every rank of this wrapper's TP group reaches this point."""
        dist.barrier(group=self._device_mesh_cpu.get_group())

    def update_weights_from_tensor(
        self, named_tensors: Iterable[Tuple[str, torch.Tensor]]
    ):
        """Push ``(name, tensor)`` pairs into the engine, then flush its cache.

        The iterable is materialized into a list before being handed to the
        engine. Non-leader ranks ignore ``named_tensors`` and only join the barrier.
        """
        if self._is_leader:
            self._engine.update_weights_from_tensor(list(named_tensors))
            self._engine.flush_cache()
        self._barrier()

    def release_memory_occupation(self, tags):
        """Ask the engine to release the memory identified by ``tags``.

        ``tags`` is passed through unchanged (callers in this file use
        ``["kv_cache"]`` and ``["weights"]``).
        """
        if self._is_leader:
            self._engine.release_memory_occupation(tags)
        self._barrier()

    def resume_memory_occupation(self, tags):
        """Ask the engine to re-acquire the memory identified by ``tags``."""
        if self._is_leader:
            self._engine.resume_memory_occupation(tags)
        self._barrier()

    def shutdown(self):
        """Shut down the engine on the leader, then synchronize the TP group."""
        if self._is_leader:
            self._engine.shutdown()
        self._barrier()
|
|
|
|
|
|
def get_gpu_memory_gb(gpu_id=0):
    """Return the used memory of CUDA device ``gpu_id`` in GiB.

    ``torch.cuda.device_memory_used`` reports device-wide usage (as given by
    nvidia-smi / amd-smi), so memory held by other processes on the same GPU is
    included, not just this process's allocator.

    Args:
        gpu_id: Index of the CUDA device to query. Previously this argument was
            ignored and the current device was always queried.
    """
    return torch.cuda.device_memory_used(device=gpu_id) / 1024**3
|
|
|
|
|
|
class TestMultiInstanceReleaseMemoryOccupation(CustomTestCase):
    """Release/resume SGLang memory across several engine instances.

    Spawns ``dp_size * tp_size`` worker processes, one per rank. Each worker runs
    ``_run_sglang_subprocess`` and reports a single boolean success flag through
    a one-way pipe.
    """

    @classmethod
    def setUpClass(cls):
        # Workers use CUDA, so they are started with "spawn" rather than "fork".
        # force=True makes this idempotent: without it, a second call in the same
        # interpreter (e.g. another test class) raises RuntimeError.
        multiprocessing.set_start_method("spawn", force=True)

    def test_multi_instance_release_memory_occupation(self):
        master_port = find_available_port(23456)

        dp_size = TEST_SUITE["dp_size"]
        tp_size = TEST_SUITE["tp_size"]
        world_size = dp_size * tp_size
        processes = []
        output_reader, output_writer = multiprocessing.Pipe(duplex=False)
        for rank in range(world_size):
            p = Process(
                target=_run_sglang_subprocess,
                kwargs=dict(
                    rank=rank,
                    dp_size=dp_size,
                    tp_size=tp_size,
                    model_path=TEST_SUITE["model_path"],
                    master_port=master_port,
                    output_writer=output_writer,
                    mem_fraction_static=TEST_SUITE["mem_fraction_static"],
                ),
            )
            p.start()
            processes.append(p)

        # Each spawned child now holds its own handle to the write end. Dropping
        # the parent's copy lets recv() raise EOFError, instead of blocking
        # forever, if every child exits without reporting.
        output_writer.close()

        for _ in range(world_size):
            self.assertTrue(
                output_reader.recv(), "Subprocess fail. Check the logs above."
            )
        for p in processes:
            p.join()
|
|
|
|
|
|
def _run_sglang_subprocess(
    rank: int,
    dp_size: int,
    tp_size: int,
    model_path: str,
    master_port: int,
    output_writer,
    mem_fraction_static: float,
):
    """Worker body for one GPU rank of the multi-instance memory test.

    The worker does the following:

    - Joins a ``dp_size * tp_size`` process group.
    - Builds one ``EngineWrapper`` per TP group.
    - Releases and resumes the engine's KV cache and weights, asserting after
      each step that the GPU memory reading moved in the expected direction.
    - Loads an HF model into the freed memory and pushes its parameters into
      the engine with ``update_weights_from_tensor``.

    Args:
        rank: Global rank of this worker; also used as its CUDA device index.
        dp_size: Number of independent engine instances.
        tp_size: Tensor-parallel size of each engine instance.
        model_path: Model served by the SGLang engine.
        master_port: Port for the ``torch.distributed`` rendezvous on localhost.
        output_writer: Write end of a multiprocessing pipe; receives exactly one
            bool from this worker (True on success, False on any exception).
        mem_fraction_static: Forwarded to the engine.
    """
    engine = None
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        dist.init_process_group(
            rank=rank,
            device_id=torch.device(f"cuda:{rank}"),
            world_size=dp_size * tp_size,
        )
        torch.cuda.set_device(rank)

        # First GPU of this rank's TP group (e.g. ranks 2 and 3 -> 2 when tp_size == 2).
        base_gpu_id = rank // tp_size * tp_size
        mesh_kwargs = dict(
            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
        )
        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
        print(
            f"subprocess[{rank=},{base_gpu_id=},{rank=},{tp_size=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
        )

        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before starting Engine: {_mem_usage}")

        # The wrapper is scoped to the "tp" sub-mesh, so each DP replica gets its own engine.
        engine = EngineWrapper(
            model_path=model_path,
            random_seed=42,
            mem_fraction_static=mem_fraction_static,
            device_mesh_cpu=inference_device_mesh_cpu["tp"],
            base_gpu_id=base_gpu_id,
        )
        print(f"subprocess[{rank=}] {engine=}", flush=True)

        # 1 - release kv cache
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before releasing Sgl KV cache: {_mem_usage}")
        engine.release_memory_occupation(tags=["kv_cache"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing KV cache must be reduced! before: {_mem_usage} vs after: {_curr_usage}"

        # 2 - release sglang weights
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage before releasing Sgl weights: {_mem_usage}")
        engine.release_memory_occupation(tags=["weights"])

        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing weights must be reduced! before: {_mem_usage} vs after: {_curr_usage}"

        # 3 - load hf model
        _mem_usage = get_gpu_memory_gb(rank)
        print(
            f"GPU{rank} Memory usage after releasing Sgl weights and kv cache: {_mem_usage}"
        )
        hf_model = AutoModelForCausalLM.from_pretrained(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
            torch_dtype="bfloat16",
            device_map=f"cuda:{rank}",
            trust_remote_code=True,
        ).cuda()
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage > _mem_usage
        ), f"Memory usage after loading hf model must be increased! before: {_mem_usage} vs after: {_curr_usage}"

        # 4 - resume sglang weights and update the weights
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after loading hf model: {_mem_usage}")
        engine.resume_memory_occupation(tags=["weights"])
        engine.update_weights_from_tensor(
            named_tensors=list(hf_model.named_parameters())
        )

        # 5 - release hf model
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after resuming Sgl weights: {_mem_usage}")
        del hf_model
        # Return the freed HF blocks from PyTorch's caching allocator to the
        # device so the device-wide reading actually drops.
        torch.cuda.empty_cache()
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage < _mem_usage
        ), f"Memory usage after releasing hf model must be reduced! before: {_mem_usage} vs after: {_curr_usage}"

        # 6 - resume sglang kv cache
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after releasing hf model: {_mem_usage}")
        engine.resume_memory_occupation(tags=["kv_cache"])
        _curr_usage = get_gpu_memory_gb(rank)
        assert (
            _curr_usage > _mem_usage
        ), f"Memory usage after resuming kv cache must be increased! before: {_mem_usage} vs after: {_curr_usage}"

        # 7 - Final checking!
        _mem_usage = get_gpu_memory_gb(rank)
        print(f"GPU{rank} Memory usage after resuming Sgl KV cache: {_mem_usage}")

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    # Report before shutting down so the parent's result does not depend on teardown.
    output_writer.send(execution_ok)
    output_writer.close()

    try:
        if engine is not None:
            engine.shutdown()
    finally:
        # Tear down the process group created above (skipped if
        # init_process_group itself failed) instead of leaking it at exit.
        if dist.is_initialized():
            dist.destroy_process_group()
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|