Support multi-thread model weight loading (#7277)

This commit is contained in:
xianzhiT
2025-06-25 01:39:10 +08:00
committed by GitHub
parent 8ecad0b16f
commit 9f1787fa60
4 changed files with 143 additions and 10 deletions

View File

@@ -1,12 +1,14 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import concurrent.futures
import fnmatch
import glob
import hashlib
import json
import logging
import os
import queue
import tempfile
from collections import defaultdict
from typing import (
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
yield name, param
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in safetensors files, loading files in parallel.

    Every file in ``hf_weights_files`` is submitted to a thread pool and loaded
    as a whole. Shards are consumed in the order they finish loading, so the
    order of the yielded tensors across files is not deterministic.

    Args:
        hf_weights_files: Paths of the safetensors files to load.
        is_all_weights_sharded: Only forwarded to the encrypted-weights
            iterator; the multi-thread path always loads entire files.
        decryption_key: If truthy, multi-thread loading is skipped (a warning
            is logged) and the work is delegated to
            ``safetensors_encrypted_weights_iterator``.
        max_workers: Number of threads in the loading pool.
        disable_mmap: If True, read each file's bytes with ``open`` and
            deserialize them with ``safetensors.torch.load`` instead of
            calling ``safetensors.torch.load_file``.

    Yields:
        ``(name, tensor)`` pairs for every entry of every loaded file.

    Raises:
        Any exception raised while loading a file is re-raised when that
        file's result is consumed.
    """
    if decryption_key:
        logger.warning(
            "Multi-Thread loading is not working for encrypted safetensor weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Only show the progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        """Load one safetensors file and return its name -> tensor mapping."""
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.torch.load(f.read())
        return safetensors.torch.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Outstanding futures live in a set so that each one can be forgotten
        # as soon as it has been consumed; a list of all futures would keep
        # every loaded shard alive until the generator finishes.
        pending = {
            executor.submit(_load_file, st_file) for st_file in hf_weights_files
        }
        # ``disable`` turns tqdm into a plain pass-through iterator, so no
        # separate non-tqdm branch is needed.
        completed = tqdm(
            concurrent.futures.as_completed(pending),
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )
        try:
            for future in completed:
                pending.discard(future)
                state_dict = future.result()
                # The future holds a reference to its result; drop it so the
                # shard can be freed once its tensors have been handed out.
                del future
                for name, param in state_dict.items():
                    yield name, param
                del state_dict
        finally:
            # On a load error or an early generator close, cancel loads that
            # have not started yet so the executor shutdown does not run them.
            # On normal completion ``pending`` is already empty.
            for unfinished in pending:
                unfinished.cancel()
def pt_weights_iterator(
hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -471,6 +527,39 @@ def pt_weights_iterator(
del state
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in bin/pt files, loading files in parallel.

    Every file in ``hf_weights_files`` is submitted to a thread pool and loaded
    with ``torch.load(..., map_location="cpu", weights_only=True)``. Shards are
    consumed in the order they finish loading, so the order of the yielded
    tensors across files is not deterministic.

    Args:
        hf_weights_files: Paths of the bin/pt checkpoint files to load.
        max_workers: Number of threads in the loading pool.

    Yields:
        ``(name, tensor)`` pairs for every entry of every loaded file.

    Raises:
        Any exception raised while loading a file is re-raised when that
        file's result is consumed.
    """
    # Only show the progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        """Load one checkpoint file onto the CPU."""
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Outstanding futures live in a set so that each one can be forgotten
        # as soon as it has been consumed; a list of all futures would keep
        # every loaded checkpoint shard alive until the generator finishes.
        pending = {
            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
        }
        # ``disable`` turns tqdm into a plain pass-through iterator, so no
        # separate non-tqdm branch is needed.
        completed = tqdm(
            concurrent.futures.as_completed(pending),
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )
        try:
            for future in completed:
                pending.discard(future)
                state = future.result()
                # The future holds a reference to its result; drop it so the
                # shard can be freed once its tensors have been handed out.
                del future
                yield from state.items()
                # Release the shard before the next one is consumed, as the
                # single-threaded pt iterator does.
                del state
        finally:
            # On a load error or an early generator close, cancel loads that
            # have not started yet so the executor shutdown does not run them.
            # On normal completion ``pending`` is already empty.
            for unfinished in pending:
                unfinished.cancel()
def get_gguf_extra_tensor_names(
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]: