[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)

This commit is contained in:
Wei Wu
2025-03-17 16:54:30 +08:00
committed by GitHub
parent c614dbdf95
commit 91ba98fe50
3 changed files with 42 additions and 18 deletions

View File

@@ -320,7 +320,10 @@ class Engine:
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
to avoid duplicated operations such as clearing cache."""
obj = UpdateWeightsFromTensorReqInput(
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
serialized_named_tensors=[
MultiprocessingSerializer.serialize(named_tensors)
for _ in range(self.server_args.tp_size)
],
load_format=load_format,
flush_cache=flush_cache,
)