Support weight loading without mmap (#7469)

This commit is contained in:
Yuhong Guo
2025-06-24 06:13:59 +08:00
committed by GitHub
parent e5ddeb04d5
commit e5afb88b1c
4 changed files with 21 additions and 2 deletions

View File

@@ -422,6 +422,7 @@ def safetensors_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.
@@ -443,7 +444,11 @@ def safetensors_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
result = safetensors.torch.load_file(st_file, device="cpu")
if disable_mmap:
with open(st_file, "rb") as f:
result = safetensors.torch.load(f.read())
else:
result = safetensors.torch.load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param