[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)
This commit is contained in:
@@ -320,7 +320,10 @@ class Engine:
|
||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
||||
to avoid duplicated operations such as clearing cache."""
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
|
||||
serialized_named_tensors=[
|
||||
MultiprocessingSerializer.serialize(named_tensors)
|
||||
for _ in range(self.server_args.tp_size)
|
||||
],
|
||||
load_format=load_format,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
|
||||
@@ -214,7 +214,7 @@ class TpModelWorker:
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Load new model weights from serialized tensors for this TP worker.

    NOTE(review): `serialized_named_tensors` is assumed to be a per-rank
    list (one serialized payload per TP rank, built on the Engine side) —
    each worker must deserialize only its own entry via `self.tp_rank`.
    Deserializing a payload that was shared across ranks is what leaked
    GPU memory (CUDA IPC handles kept alive by other ranks); confirm the
    Engine-side serialization matches this layout.
    """
    success, message = self.model_runner.update_weights_from_tensor(
        named_tensors=MultiprocessingSerializer.deserialize(
            # Index by rank: the diff residue kept both the old whole-list
            # form and this per-rank form; only the per-rank form is correct.
            recv_req.serialized_named_tensors[self.tp_rank]
        ),
        load_format=recv_req.load_format,
    )
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import time
|
||||
import unittest
|
||||
|
||||
@@ -7,24 +8,44 @@ import sglang as sgl
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
def test_update_weights_from_tensor(tp_size):
    """End-to-end check that updating weights from an in-memory tensor works
    at the given TP size and does not leak GPU memory afterwards."""
    assert torch.cuda.device_count() >= tp_size, f"At least {tp_size} GPUs are required"
    torch.cuda.empty_cache()

    engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, tp_size=tp_size)

    param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]

    # Sanity-check the pristine weights before touching anything.
    _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])

    memory_before = torch.cuda.memory_allocated()
    new_tensor = torch.full((16384, 2048), 1.5, device="cuda")

    time_start = time.time()
    updates = [(name, new_tensor) for name in param_names]
    engine.update_weights_from_tensor(updates)
    print(f"Time delta: {time.time() - time_start:.03f}")

    # Spot-check a few parameters to confirm the update landed.
    for name in param_names[:3]:
        _check_param(engine, name, [1.5] * 5)

    engine.shutdown()

    # Drop our reference, reclaim IPC handles, and verify no residual
    # allocation survives the update (the leak this test guards against).
    del new_tensor
    gc.collect()
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
    memory_after = torch.cuda.memory_allocated()
    assert (
        memory_after <= memory_before + 1024
    ), f"Memory leak detected: {memory_after - memory_before} bytes"
|
||||
|
||||
|
||||
class TestUpdateWeightsFromTensor(unittest.TestCase):
    def test_update_weights_from_tensor(self):
        """Run the weight-update round trip at each supported TP size.

        The diff residue here retained both the removed single-TP inline
        body and the added parametrized body; only the parametrized body
        (the commit's replacement) is kept. `subTest` ensures a failure at
        one tp_size does not mask the other.
        """
        tp_sizes = [1, 2]
        for tp_size in tp_sizes:
            with self.subTest(tp_size=tp_size):
                test_update_weights_from_tensor(tp_size)
|
||||
|
||||
def test_update_weights_from_tensor_load_format_direct(self):
|
||||
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
Reference in New Issue
Block a user