diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 26fa77b88..a100e2785 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -895,8 +895,12 @@ class ModelRunner: named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]], load_format: Optional[str] = None, ): + monkey_patch_torch_reductions() + # Query the current device only after applying the patch; otherwise the reported device would be wrong + infered_device = torch.cuda.current_device() + named_tensors = [ - (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank)) + (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device)) for name, tensor in named_tensors ] if load_format == "direct": @@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso default_weight_loader(params_dict[name], tensor) -def _unwrap_tensor(tensor, tp_rank): +def _unwrap_tensor(tensor, tp_rank, device): if isinstance(tensor, LocalSerializedTensor): - monkey_patch_torch_reductions() tensor = tensor.get(tp_rank) - return tensor.to(torch.cuda.current_device()) + return tensor.to(device) @dataclass