Support multi-thread model weight loading (#7277)
This commit is contained in:
@@ -547,6 +547,7 @@ class ModelRunner:
|
||||
self.load_config = LoadConfig(
|
||||
load_format=self.server_args.load_format,
|
||||
download_dir=self.server_args.download_dir,
|
||||
model_loader_extra_config=self.server_args.model_loader_extra_config,
|
||||
)
|
||||
if self.server_args.load_format == "gguf":
|
||||
monkey_patch_vllm_gguf_config()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
# ruff: noqa: SIM117
|
||||
import collections
|
||||
import concurrent
|
||||
import dataclasses
|
||||
import fnmatch
|
||||
import glob
|
||||
@@ -11,14 +12,17 @@ import math
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from torch import nn
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
|
||||
set_default_torch_dtype,
|
||||
)
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
_BAR_FORMAT,
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files,
|
||||
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
get_quant_config,
|
||||
gguf_quant_weights_iterator,
|
||||
initialize_dummy_weights,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator,
|
||||
np_cache_weights_iterator,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
|
||||
class DefaultModelLoader(BaseModelLoader):
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
# default number of thread when enable multithread weight loading
|
||||
DEFAULT_NUM_THREADS = 8
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Source:
|
||||
"""A source for weights."""
|
||||
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
    """Initialize the loader and validate ``model_loader_extra_config``.

    Only the multi-thread loading knobs are understood by this loader:
    ``enable_multithread_load`` and ``num_threads``.  Any other key is
    rejected early so a typo does not silently disable the feature.

    Raises:
        ValueError: If ``model_loader_extra_config`` contains keys other
            than the two supported ones.
    """
    super().__init__(load_config)
    if load_config.model_loader_extra_config:
        extra_config = load_config.model_loader_extra_config
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            # NOTE(review): the original raised a garbled message — two
            # alternative f-string messages were concatenated together.
            # Keep only the precise one that names the offending keys.
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unexpected_keys}"
            )
|
||||
|
||||
def _maybe_download_from_modelscope(
|
||||
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self, source: "Source"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
extra_config = self.load_config.model_loader_extra_config
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt
|
||||
)
|
||||
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
weight_loader_disable_mmap = global_server_args_dict.get(
|
||||
"weight_loader_disable_mmap"
|
||||
)
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files, disable_mmap=weight_loader_disable_mmap
|
||||
)
|
||||
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS
|
||||
),
|
||||
disable_mmap=weight_loader_disable_mmap,
|
||||
)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files, disable_mmap=weight_loader_disable_mmap
|
||||
)
|
||||
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(hf_weights_files)
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS
|
||||
),
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(hf_weights_files)
|
||||
|
||||
# Apply the prefix.
|
||||
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
|
||||
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config,
|
||||
)
|
||||
|
||||
self.load_weights_and_postprocess(
|
||||
model, self._get_all_weights(model_config, model), target_device
|
||||
)
|
||||
self.load_weights_and_postprocess(
|
||||
model, self._get_all_weights(model_config, model), target_device
|
||||
)
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
|
||||
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
|
||||
yield name, param
|
||||
|
||||
|
||||
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over safetensor weights, loading shard files concurrently.

    Each shard file is read in full by a worker thread; its tensors are
    yielded as soon as that shard finishes loading (completion order, not
    file order).

    Args:
        hf_weights_files: Paths of the safetensor shard files.
        is_all_weights_sharded: Forwarded to the encrypted-weights fallback
            only; the multi-thread path always reads whole files.
        decryption_key: If given, multi-thread loading is skipped and the
            sequential encrypted iterator is used instead.
        max_workers: Number of loader threads.
        disable_mmap: Read each file's bytes eagerly instead of mmap-ing it.
    """
    if decryption_key:
        # Encrypted checkpoints are not supported by the threaded path;
        # fall back to the single-threaded encrypted iterator.
        logger.warning(
            "Multi-Thread loading is not working for encrypted safetensor weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Only rank 0 (or a non-distributed run) shows the progress bar.
    show_progress = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _read_shard(path: str):
        # With mmap disabled the raw bytes are read eagerly; otherwise the
        # tensors stay memory-mapped on CPU.
        if disable_mmap:
            with open(path, "rb") as fh:
                return safetensors.torch.load(fh.read())
        return safetensors.torch.load_file(path, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_read_shard, path) for path in hf_weights_files]
        completed = concurrent.futures.as_completed(pending)

        if show_progress:
            completed = tqdm(
                completed,
                total=len(hf_weights_files),
                desc="Multi-thread loading shards",
                bar_format=_BAR_FORMAT,
            )

        for done in completed:
            yield from done.result().items()
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
@@ -471,6 +527,39 @@ def pt_weights_iterator(
|
||||
del state
|
||||
|
||||
|
||||
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over bin/pt checkpoint weights, loading files concurrently.

    Each checkpoint file is deserialized by a worker thread; its tensors
    are yielded as soon as that file finishes loading (completion order,
    not file order).

    Args:
        hf_weights_files: Paths of the .bin/.pt checkpoint files.
        max_workers: Number of loader threads.
    """
    # Only rank 0 (or a non-distributed run) shows the progress bar.
    show_progress = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _read_checkpoint(path: str):
        # weights_only=True refuses arbitrary pickled code in the file.
        return torch.load(path, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(_read_checkpoint, path) for path in hf_weights_files
        ]
        completed = concurrent.futures.as_completed(pending)

        if show_progress:
            completed = tqdm(
                completed,
                total=len(hf_weights_files),
                desc="Multi-thread loading pt checkpoint shards",
                bar_format=_BAR_FORMAT,
            )

        for done in completed:
            yield from done.result().items()
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> List[str]:
|
||||
|
||||
@@ -47,6 +47,7 @@ class ServerArgs:
|
||||
tokenizer_mode: str = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
load_format: str = "auto"
|
||||
model_loader_extra_config: str = "{}"
|
||||
trust_remote_code: bool = False
|
||||
dtype: str = "auto"
|
||||
kv_cache_dtype: str = "auto"
|
||||
@@ -632,6 +633,13 @@ class ServerArgs:
|
||||
"layer before loading another to make the peak memory envelope "
|
||||
"smaller.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-loader-extra-config",
|
||||
type=str,
|
||||
help="Extra config for model loader. "
|
||||
"This will be passed to the model loader corresponding to the chosen load_format.",
|
||||
default=ServerArgs.model_loader_extra_config,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user