Improve weight loading and code style (#3174)
This commit is contained in:
@@ -404,8 +404,13 @@ def np_cache_weights_iterator(
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
is_all_weights_sharded: bool = False,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
"""Iterate over the weights in the model safetensor files.
|
||||
|
||||
If is_all_weights_sharded is True, it uses a more optimized read by reading an
|
||||
entire file instead of reading each tensor one by one.
|
||||
"""
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
@@ -415,9 +420,14 @@ def safetensors_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
if not is_all_weights_sharded:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
else:
|
||||
result = load_file(st_file, device="cpu")
|
||||
for name, param in result.items():
|
||||
yield name, param
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user