Speed up update_weights_from_tensor (#2695)
This commit is contained in:
@@ -17,7 +17,7 @@ import gc
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -428,9 +428,9 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
|
||||
self.model.load_weights([(name, tensor)])
|
||||
return True, "Success" # TODO error handling
|
||||
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
self.model.load_weights(named_tensors)
|
||||
return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
|
||||
Reference in New Issue
Block a user