Optimize: Cache CUDA device to reduce redundant calls during tensor l… (#8996)
@@ -895,8 +895,12 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
         monkey_patch_torch_reductions()
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
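The hunk above hoists the torch.cuda.current_device() query out of the per-tensor list comprehension: the device is resolved once, after monkey_patch_torch_reductions() has run, and then reused for every tensor. Below is a minimal sketch of that before/after pattern; it only exercises the GPU path when CUDA is available, and the names to_device_per_call, to_device_cached, and tensors are illustrative rather than taken from the repository.

import time

import torch

# Illustrative micro-benchmark (names are not from the repository): contrast
# querying the current CUDA device once per tensor with caching it before the
# loop, which is the optimization this commit applies.

def to_device_per_call(tensors):
    # Old pattern: one torch.cuda.current_device() call per tensor.
    return [t.to(torch.cuda.current_device()) for t in tensors]

def to_device_cached(tensors):
    # New pattern: query the device once and reuse it for every tensor.
    device = torch.cuda.current_device()
    return [t.to(device) for t in tensors]

if torch.cuda.is_available():
    tensors = [torch.ones(1) for _ in range(10_000)]
    for fn in (to_device_per_call, to_device_cached):
        torch.cuda.synchronize()
        start = time.perf_counter()
        fn(tensors)
        torch.cuda.synchronize()
        print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")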
@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
         default_weight_loader(params_dict[name], tensor)
 
 
-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(torch.cuda.current_device())
+    return tensor.to(device)
 
 
 @dataclass
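The second hunk makes _unwrap_tensor a pure consumer of the injected device: it no longer calls torch.cuda.current_device() per tensor, and its inner monkey_patch_torch_reductions() call goes away because the caller now patches once up front, before reading the device (per the in-diff comment, the device would be wrong if read before patching). A hedged usage sketch of the new call shape follows, with LocalSerializedTensor handling elided and placeholder tensor names; it is a simplified stand-in, not the repository's implementation.

import torch

# Simplified stand-in for the post-commit signature: LocalSerializedTensor
# handling is elided, so tp_rank is accepted but unused here.
def _unwrap_tensor(tensor, tp_rank, device):
    return tensor.to(device)

# Caller side: resolve the device once, then reuse it for every tensor.
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
named = [("weight", torch.zeros(4)), ("bias", torch.zeros(4))]
named = [(n, _unwrap_tensor(t, tp_rank=0, device=device)) for n, t in named]

Injecting the device as a parameter also keeps the patch-then-query ordering in exactly one place, the caller, instead of re-deriving it inside a loop body.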