Support weight loading without mmap (#7469)
This commit is contained in:
@@ -422,6 +422,7 @@ def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
is_all_weights_sharded: bool = False,
|
||||
decryption_key: Optional[str] = None,
|
||||
disable_mmap: bool = False,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files.
|
||||
|
||||
@@ -443,7 +444,11 @@ def safetensors_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
result = safetensors.torch.load_file(st_file, device="cpu")
|
||||
if disable_mmap:
|
||||
with open(st_file, "rb") as f:
|
||||
result = safetensors.torch.load(f.read())
|
||||
else:
|
||||
result = safetensors.torch.load_file(st_file, device="cpu")
|
||||
for name, param in result.items():
|
||||
yield name, param
|
||||
|
||||
|
||||
Reference in New Issue
Block a user