Support multi-thread model weight loading (#7277)
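The diff below adds threaded variants of the existing per-format weight iterators: each checkpoint shard is loaded in its own thread-pool task, and tensors are yielded as shards complete. A minimal sketch of the underlying pattern, with a hypothetical load_shard standing in for the real safetensors/torch loaders in the diff:

import concurrent.futures
from typing import Dict, Generator, Iterable, Tuple

import torch


def load_shard(path: str) -> Dict[str, torch.Tensor]:
    # Hypothetical stand-in for safetensors.torch.load_file / torch.load;
    # returns a dummy tensor so the sketch runs without checkpoint files.
    return {f"{path}.weight": torch.zeros(1)}


def threaded_weights_iterator(
    paths: Iterable[str], max_workers: int = 4
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    # One load task per shard file; yield (name, tensor) pairs as shards finish.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(load_shard, p) for p in paths]
        for future in concurrent.futures.as_completed(futures):
            yield from future.result().items()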
@@ -1,12 +1,14 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py

"""Utilities for downloading and initializing model weights."""

import concurrent.futures
import fnmatch
import glob
import hashlib
import json
import logging
import os
import queue
import tempfile
from collections import defaultdict
from typing import (
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
            yield name, param


def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensors files using multiple threads.

    If is_all_weights_sharded is True, a more optimized read path is used:
    each file is read in its entirety instead of tensor by tensor.
    """
    if decryption_key:
        # Encrypted checkpoints are not supported by the threaded path;
        # fall back to the sequential encrypted iterator.
        logger.warning(
            "Multi-thread loading does not work for encrypted safetensors weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Only show the progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        if disable_mmap:
            # Read the whole file into memory instead of mmap-ing it.
            with open(st_file, "rb") as f:
                result = safetensors.torch.load(f.read())
        else:
            result = safetensors.torch.load_file(st_file, device="cpu")

        return result

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]

        if enable_tqdm:
            futures_iter = tqdm(
                concurrent.futures.as_completed(futures),
                total=len(hf_weights_files),
                desc="Multi-thread loading shards",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
            )
        else:
            futures_iter = concurrent.futures.as_completed(futures)

        for future in futures_iter:
            state_dict = future.result()
            for name, param in state_dict.items():
                yield name, param
def pt_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
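For context, a hedged usage sketch of the new safetensors iterator above; the shard paths are hypothetical, and real callers obtain them from the download/glob helpers earlier in this module:

shard_files = [
    "model-00001-of-00002.safetensors",  # hypothetical shard paths
    "model-00002-of-00002.safetensors",
]

state_dict = {}
for name, tensor in multi_thread_safetensors_weights_iterator(
    shard_files,
    max_workers=8,       # up to 8 shards loaded concurrently
    disable_mmap=False,  # keep the default mmap-backed load_file path
):
    state_dict[name] = tensor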
@@ -471,6 +527,39 @@ def pt_weights_iterator(
        del state


def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files using multiple threads."""
    # Only show the progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        # weights_only=True restricts unpickling to tensors and plain containers.
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
        ]

        if enable_tqdm:
            futures_iter = tqdm(
                concurrent.futures.as_completed(futures),
                total=len(hf_weights_files),
                desc="Multi-thread loading pt checkpoint shards",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
            )
        else:
            futures_iter = concurrent.futures.as_completed(futures)

        for future in futures_iter:
            state = future.result()
            yield from state.items()
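Note that both threaded iterators yield shards in completion order rather than submission order, since concurrent.futures.as_completed returns futures as they finish; this is harmless here because every tensor is keyed by name. A minimal, self-contained demonstration of that ordering (sleep-based stand-in tasks, no checkpoint files involved):

import concurrent.futures
import time


def fake_shard(i: int) -> int:
    # Later submissions finish sooner, so completion order is reversed.
    time.sleep(0.3 - 0.1 * i)
    return i


with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(fake_shard, i) for i in range(3)]
    done = [f.result() for f in concurrent.futures.as_completed(futures)]
    print(done)  # typically [2, 1, 0], not the submission order [0, 1, 2]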
def get_gguf_extra_tensor_names(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]: