Tiny change: make device_loading_context more static (#9478)
This commit is contained in:
@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
|
|||||||
yield module
|
yield module
|
||||||
return
|
return
|
||||||
|
|
||||||
original_device_states: Dict[str, torch.device] = {}
|
original_infos: Dict[str, Dict] = {}
|
||||||
|
|
||||||
# Store original device states and move parameters to GPU if they're on CPU
|
# Store original device states and move parameters to GPU if they're on CPU
|
||||||
for name, p in module.named_parameters():
|
for name, p in module.named_parameters():
|
||||||
if p.device.type == "cpu":
|
if p.device.type == "cpu":
|
||||||
original_device_states[name] = p.device
|
original_data = p.data
|
||||||
p.data = p.data.to(target_device)
|
device_data = p.data.to(target_device)
|
||||||
|
original_infos[name] = dict(
|
||||||
|
device=p.device,
|
||||||
|
original_data=original_data,
|
||||||
|
device_data=device_data,
|
||||||
|
)
|
||||||
|
p.data = device_data
|
||||||
# Parameters already on target device are not touched
|
# Parameters already on target device are not touched
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
|
|||||||
# Restore parameters to their original devices, ignoring new parameters
|
# Restore parameters to their original devices, ignoring new parameters
|
||||||
pin_memory = is_pin_memory_available()
|
pin_memory = is_pin_memory_available()
|
||||||
for name, p in module.named_parameters():
|
for name, p in module.named_parameters():
|
||||||
if name in original_device_states:
|
if name in original_infos:
|
||||||
original_device: torch.device = original_device_states[name]
|
original_info = original_infos[name]
|
||||||
if original_device.type == "cpu":
|
device_data = original_info["device_data"]
|
||||||
|
original_data = original_info["original_data"]
|
||||||
|
original_device: torch.device = original_info["device"]
|
||||||
|
|
||||||
|
if (
|
||||||
|
(device_data.device == p.data.device)
|
||||||
|
and (device_data.data_ptr() == p.data.data_ptr())
|
||||||
|
and (device_data.shape == p.data.shape)
|
||||||
|
and (device_data.dtype == p.data.dtype)
|
||||||
|
):
|
||||||
|
original_data.copy_(p.data.to(original_data.device))
|
||||||
|
p.data = original_data
|
||||||
|
elif original_device.type == "cpu":
|
||||||
# `torch.empty_like` does not support `pin_memory` argument
|
# `torch.empty_like` does not support `pin_memory` argument
|
||||||
cpu_data = torch.empty_strided(
|
cpu_data = torch.empty_strided(
|
||||||
size=p.data.size(),
|
size=p.data.size(),
|
||||||
|
|||||||
Reference in New Issue
Block a user