Improve weight loading and code style (#3174)

This commit is contained in:
Lianmin Zheng
2025-01-27 03:00:41 -08:00
committed by GitHub
parent 351a72d40b
commit 53cef81587
11 changed files with 171 additions and 65 deletions

View File

@@ -404,8 +404,13 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
"""Iterate over the weights in the model safetensor files.
If is_all_weights_sharded is True, it uses a more optimized read path: it loads
the entire file at once instead of reading each tensor one by one.
"""
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
@@ -415,9 +420,14 @@ def safetensors_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
if not is_all_weights_sharded:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
result = load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param