fix(model loader): use safe_open to prevent file handle leaks. (#7684)
This commit is contained in:
@@ -460,10 +460,12 @@ def safetensors_weights_iterator(
|
|||||||
if disable_mmap:
|
if disable_mmap:
|
||||||
with open(st_file, "rb") as f:
|
with open(st_file, "rb") as f:
|
||||||
result = safetensors.torch.load(f.read())
|
result = safetensors.torch.load(f.read())
|
||||||
|
for name, param in result.items():
|
||||||
|
yield name, param
|
||||||
else:
|
else:
|
||||||
result = safetensors.torch.load_file(st_file, device="cpu")
|
with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
|
||||||
for name, param in result.items():
|
for name in f.keys():
|
||||||
yield name, param
|
yield name, f.get_tensor(name)
|
||||||
|
|
||||||
|
|
||||||
def multi_thread_safetensors_weights_iterator(
|
def multi_thread_safetensors_weights_iterator(
|
||||||
@@ -496,7 +498,8 @@ def multi_thread_safetensors_weights_iterator(
|
|||||||
with open(st_file, "rb") as f:
|
with open(st_file, "rb") as f:
|
||||||
result = safetensors.torch.load(f.read())
|
result = safetensors.torch.load(f.read())
|
||||||
else:
|
else:
|
||||||
result = safetensors.torch.load_file(st_file, device="cpu")
|
with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
|
||||||
|
result = {k: f.get_tensor(k) for k in f.keys()}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user