diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 8dd0a4a15..db5e3b3cb 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -460,10 +460,12 @@ def safetensors_weights_iterator( if disable_mmap: with open(st_file, "rb") as f: result = safetensors.torch.load(f.read()) + for name, param in result.items(): + yield name, param else: - result = safetensors.torch.load_file(st_file, device="cpu") - for name, param in result.items(): - yield name, param + with safetensors.safe_open(st_file, framework="pt", device="cpu") as f: + for name in f.keys(): + yield name, f.get_tensor(name) def multi_thread_safetensors_weights_iterator( @@ -496,7 +498,8 @@ def multi_thread_safetensors_weights_iterator( with open(st_file, "rb") as f: result = safetensors.torch.load(f.read()) else: - result = safetensors.torch.load_file(st_file, device="cpu") + with safetensors.safe_open(st_file, framework="pt", device="cpu") as f: + result = {k: f.get_tensor(k) for k in f.keys()} return result