From 34e5e11f0ff6f2111241915d88f59fc44dfcf200 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 23 Aug 2025 17:07:15 +0800 Subject: [PATCH] Tiny make device_loading_context more static (#9478) --- python/sglang/srt/model_loader/loader.py | 30 +++++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 95d41a050..23d70be44 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) yield module return - original_device_states: Dict[str, torch.device] = {} + original_infos: Dict[str, Dict] = {} # Store original device states and move parameters to GPU if they're on CPU for name, p in module.named_parameters(): if p.device.type == "cpu": - original_device_states[name] = p.device - p.data = p.data.to(target_device) + original_data = p.data + device_data = p.data.to(target_device) + original_infos[name] = dict( + device=p.device, + original_data=original_data, + device_data=device_data, + ) + p.data = device_data # Parameters already on target device are not touched try: @@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) # Restore parameters to their original devices, ignoring new parameters pin_memory = is_pin_memory_available() for name, p in module.named_parameters(): - if name in original_device_states: - original_device: torch.device = original_device_states[name] - if original_device.type == "cpu": + if name in original_infos: + original_info = original_infos[name] + device_data = original_info["device_data"] + original_data = original_info["original_data"] + original_device: torch.device = original_info["device"] + + if ( + (device_data.device == p.data.device) + and (device_data.data_ptr() == p.data.data_ptr()) + and (device_data.shape == p.data.shape) + and (device_data.dtype == p.data.dtype) + ): + original_data.copy_(p.data.to(original_data.device)) + p.data = original_data + elif original_device.type == "cpu": # `torch.empty_like` does not support `pin_memory` argument cpu_data = torch.empty_strided( size=p.data.size(),