Tiny make device_loading_context more static (#9478)
This commit is contained in:
@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: Dict[str, torch.device] = {}
|
||||
original_infos: Dict[str, Dict] = {}
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
if p.device.type == "cpu":
|
||||
original_device_states[name] = p.device
|
||||
p.data = p.data.to(target_device)
|
||||
original_data = p.data
|
||||
device_data = p.data.to(target_device)
|
||||
original_infos[name] = dict(
|
||||
device=p.device,
|
||||
original_data=original_data,
|
||||
device_data=device_data,
|
||||
)
|
||||
p.data = device_data
|
||||
# Parameters already on target device are not touched
|
||||
|
||||
try:
|
||||
@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
|
||||
# Restore parameters to their original devices, ignoring new parameters
|
||||
pin_memory = is_pin_memory_available()
|
||||
for name, p in module.named_parameters():
|
||||
if name in original_device_states:
|
||||
original_device: torch.device = original_device_states[name]
|
||||
if original_device.type == "cpu":
|
||||
if name in original_infos:
|
||||
original_info = original_infos[name]
|
||||
device_data = original_info["device_data"]
|
||||
original_data = original_info["original_data"]
|
||||
original_device: torch.device = original_info["device"]
|
||||
|
||||
if (
|
||||
(device_data.device == p.data.device)
|
||||
and (device_data.data_ptr() == p.data.data_ptr())
|
||||
and (device_data.shape == p.data.shape)
|
||||
and (device_data.dtype == p.data.dtype)
|
||||
):
|
||||
original_data.copy_(p.data.to(original_data.device))
|
||||
p.data = original_data
|
||||
elif original_device.type == "cpu":
|
||||
# `torch.empty_like` does not support `pin_memory` argument
|
||||
cpu_data = torch.empty_strided(
|
||||
size=p.data.size(),
|
||||
|
||||
Reference in New Issue
Block a user