[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)
This commit is contained in:
@@ -320,7 +320,10 @@ class Engine:
|
||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
||||
to avoid duplicated operations such as clearing cache."""
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
|
||||
serialized_named_tensors=[
|
||||
MultiprocessingSerializer.serialize(named_tensors)
|
||||
for _ in range(self.server_args.tp_size)
|
||||
],
|
||||
load_format=load_format,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user