Optimize: Cache CUDA device to reduce redundant calls during tensor loading (#8996)
This commit is contained in:
@@ -895,8 +895,12 @@ class ModelRunner:
|
|||||||
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
|
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
|
||||||
load_format: Optional[str] = None,
|
load_format: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
monkey_patch_torch_reductions()
|
||||||
|
# We need to get the device after the patch; otherwise the device would be wrong
|
||||||
|
infered_device = torch.cuda.current_device()
|
||||||
|
|
||||||
named_tensors = [
|
named_tensors = [
|
||||||
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
|
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
|
||||||
for name, tensor in named_tensors
|
for name, tensor in named_tensors
|
||||||
]
|
]
|
||||||
if load_format == "direct":
|
if load_format == "direct":
|
||||||
@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
|
|||||||
default_weight_loader(params_dict[name], tensor)
|
default_weight_loader(params_dict[name], tensor)
|
||||||
|
|
||||||
|
|
||||||
def _unwrap_tensor(tensor, tp_rank):
|
def _unwrap_tensor(tensor, tp_rank, device):
|
||||||
if isinstance(tensor, LocalSerializedTensor):
|
if isinstance(tensor, LocalSerializedTensor):
|
||||||
monkey_patch_torch_reductions()
|
|
||||||
tensor = tensor.get(tp_rank)
|
tensor = tensor.get(tp_rank)
|
||||||
return tensor.to(torch.cuda.current_device())
|
return tensor.to(device)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
Reference in New Issue
Block a user