[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)
This commit is contained in:
@@ -320,7 +320,10 @@ class Engine:
|
|||||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
||||||
to avoid duplicated operations such as clearing cache."""
|
to avoid duplicated operations such as clearing cache."""
|
||||||
obj = UpdateWeightsFromTensorReqInput(
|
obj = UpdateWeightsFromTensorReqInput(
|
||||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
|
serialized_named_tensors=[
|
||||||
|
MultiprocessingSerializer.serialize(named_tensors)
|
||||||
|
for _ in range(self.server_args.tp_size)
|
||||||
|
],
|
||||||
load_format=load_format,
|
load_format=load_format,
|
||||||
flush_cache=flush_cache,
|
flush_cache=flush_cache,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class TpModelWorker:
|
|||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||||
success, message = self.model_runner.update_weights_from_tensor(
|
success, message = self.model_runner.update_weights_from_tensor(
|
||||||
named_tensors=MultiprocessingSerializer.deserialize(
|
named_tensors=MultiprocessingSerializer.deserialize(
|
||||||
recv_req.serialized_named_tensors
|
recv_req.serialized_named_tensors[self.tp_rank]
|
||||||
),
|
),
|
||||||
load_format=recv_req.load_format,
|
load_format=recv_req.load_format,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import gc
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -7,24 +8,44 @@ import sglang as sgl
|
|||||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_weights_from_tensor(tp_size):
    """Update engine weights from a CUDA tensor and check for leaked GPU memory.

    Launches an ``sgl.Engine`` with ``tp_size`` tensor-parallel workers,
    overwrites ten ``mlp.up_proj`` weights (layers 6-15) with a constant
    tensor, verifies the first three through ``_check_param``, shuts the
    engine down, and finally asserts that ``torch.cuda.memory_allocated()``
    did not grow by more than 1024 bytes.

    Args:
        tp_size: Tensor-parallel size passed to ``sgl.Engine``; at least this
            many CUDA devices must be visible.

    Raises:
        AssertionError: If too few GPUs are available, a parameter check
            fails, or more memory is allocated after the update than before.
    """
    assert torch.cuda.device_count() >= tp_size, f"At least {tp_size} GPUs are required"
    torch.cuda.empty_cache()

    engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, tp_size=tp_size)
    try:
        param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]

        # Values expected in the weight before it is overwritten.
        _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])

        # Baseline is taken before allocating the replacement tensor, so the
        # tensor itself must be released again before the final comparison.
        memory_before = torch.cuda.memory_allocated()
        new_tensor = torch.full((16384, 2048), 1.5, device="cuda")

        time_start = time.time()
        engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
        print(f"Time delta: {time.time() - time_start:.03f}")

        for param_name in param_names[:3]:
            _check_param(engine, param_name, [1.5] * 5)
    finally:
        # Always stop the engine, even when a check above fails; otherwise it
        # would still be running when the caller moves on to the next tp_size.
        engine.shutdown()

    # Drop every reference to the tensor and flush the allocator caches so
    # that anything still counted as allocated is a genuine leak.
    del new_tensor
    gc.collect()
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
    memory_after = torch.cuda.memory_allocated()
    assert (
        memory_after <= memory_before + 1024
    ), f"Memory leak detected: {memory_after - memory_before} bytes"
|
||||||
|
|
||||||
|
|
||||||
class TestUpdateWeightsFromTensor(unittest.TestCase):
|
class TestUpdateWeightsFromTensor(unittest.TestCase):
|
||||||
def test_update_weights_from_tensor(self):
|
def test_update_weights_from_tensor(self):
|
||||||
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
tp_sizes = [1, 2]
|
||||||
|
for tp_size in tp_sizes:
|
||||||
param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]
|
with self.subTest(tp_size=tp_size):
|
||||||
|
test_update_weights_from_tensor(tp_size)
|
||||||
_check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
|
|
||||||
|
|
||||||
new_tensor = torch.full((16384, 2048), 1.5)
|
|
||||||
|
|
||||||
time_start = time.time()
|
|
||||||
engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
|
|
||||||
print(f"Time delta: {time.time() - time_start:.03f}")
|
|
||||||
|
|
||||||
for param_name in param_names[:3]:
|
|
||||||
_check_param(engine, param_name, [1.5] * 5)
|
|
||||||
|
|
||||||
engine.shutdown()
|
|
||||||
|
|
||||||
def test_update_weights_from_tensor_load_format_direct(self):
|
def test_update_weights_from_tensor_load_format_direct(self):
|
||||||
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||||
|
|||||||
Reference in New Issue
Block a user