Add update_weights_from_tensor (#2631)

This commit is contained in:
fzyzcjy
2024-12-29 05:30:27 +08:00
committed by GitHub
parent 7863e4368a
commit fd28640dc5
10 changed files with 120 additions and 1 deletions

View File

@@ -429,6 +429,10 @@ class ModelRunner:
logger.error(error_msg)
return False, error_msg
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
self.model.load_weights([(name, tensor)])
return True, "Success" # TODO error handling
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]: