[Fix] Resolve GPU Memory Leak in update_weights_from_tensor (#4446)
This commit is contained in:
@@ -320,7 +320,10 @@ class Engine:
|
||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
|
||||
to avoid duplicated operations such as clearing cache."""
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
|
||||
serialized_named_tensors=[
|
||||
MultiprocessingSerializer.serialize(named_tensors)
|
||||
for _ in range(self.server_args.tp_size)
|
||||
],
|
||||
load_format=load_format,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
|
||||
@@ -214,7 +214,7 @@ class TpModelWorker:
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
named_tensors=MultiprocessingSerializer.deserialize(
|
||||
recv_req.serialized_named_tensors
|
||||
recv_req.serialized_named_tensors[self.tp_rank]
|
||||
),
|
||||
load_format=recv_req.load_format,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user