[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)

This commit is contained in:
Wei Wu
2025-03-17 16:54:30 +08:00
committed by GitHub
parent c614dbdf95
commit 91ba98fe50
3 changed files with 42 additions and 18 deletions

View File

@@ -320,7 +320,10 @@ class Engine:
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
to avoid duplicated operations such as clearing cache."""
obj = UpdateWeightsFromTensorReqInput(
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
serialized_named_tensors=[
MultiprocessingSerializer.serialize(named_tensors)
for _ in range(self.server_args.tp_size)
],
load_format=load_format,
flush_cache=flush_cache,
)

View File

@@ -214,7 +214,7 @@ class TpModelWorker:
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
success, message = self.model_runner.update_weights_from_tensor(
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors
recv_req.serialized_named_tensors[self.tp_rank]
),
load_format=recv_req.load_format,
)