Optimize: Cache CUDA device to reduce redundant calls during tensor l… (#8996)

This commit is contained in:
JiLi
2025-08-10 15:32:57 +08:00
committed by GitHub
parent 473400e452
commit 6b847a9a05

View File

@@ -895,8 +895,12 @@ class ModelRunner:
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
load_format: Optional[str] = None,
):
monkey_patch_torch_reductions()
# We need to get device after patch otherwise the device would be wrong
infered_device = torch.cuda.current_device()
named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
for name, tensor in named_tensors
]
if load_format == "direct":
@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
default_weight_loader(params_dict[name], tensor)
def _unwrap_tensor(tensor, tp_rank):
def _unwrap_tensor(tensor, tp_rank, device):
if isinstance(tensor, LocalSerializedTensor):
monkey_patch_torch_reductions()
tensor = tensor.get(tp_rank)
return tensor.to(torch.cuda.current_device())
return tensor.to(device)
@dataclass