Add update_weights_from_tensor (#2631)
This commit is contained in:
@@ -429,6 +429,10 @@ class ModelRunner:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
|
||||
self.model.load_weights([(name, tensor)])
|
||||
return True, "Success" # TODO error handling
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user