Speed up update_weights_from_tensor (#2695)

This commit is contained in:
fzyzcjy
2025-01-02 18:05:19 +08:00
committed by GitHub
parent 148254d4db
commit 9183c23eca
6 changed files with 48 additions and 25 deletions

View File

@@ -17,7 +17,7 @@ import gc
import json
import logging
import time
from typing import Optional
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
@@ -428,9 +428,9 @@ class ModelRunner:
logger.error(error_msg)
return False, error_msg
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
self.model.load_weights([(name, tensor)])
return True, "Success" # TODO error handling
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
self.model.load_weights(named_tensors)
return True, "Success"
def get_weights_by_name(
self, name: str, truncate_size: int = 100