From 6b847a9a05b3c0870968bd5a274207349034cf4e Mon Sep 17 00:00:00 2001
From: JiLi
Date: Sun, 10 Aug 2025 15:32:57 +0800
Subject: [PATCH] =?UTF-8?q?Optimize:=20Cache=20CUDA=20device=20to=20reduce?=
 =?UTF-8?q?=20redundant=20calls=20during=20tensor=20l=E2=80=A6=20(#8996)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/sglang/srt/model_executor/model_runner.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 26fa77b88..a100e2785 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -895,8 +895,12 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
         default_weight_loader(params_dict[name], tensor)
 
 
-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(torch.cuda.current_device())
+    return tensor.to(device)
 
 
 @dataclass