Support multi-thread model weight loading (#7277)
This commit is contained in:
@@ -547,6 +547,7 @@ class ModelRunner:
|
|||||||
self.load_config = LoadConfig(
|
self.load_config = LoadConfig(
|
||||||
load_format=self.server_args.load_format,
|
load_format=self.server_args.load_format,
|
||||||
download_dir=self.server_args.download_dir,
|
download_dir=self.server_args.download_dir,
|
||||||
|
model_loader_extra_config=self.server_args.model_loader_extra_config,
|
||||||
)
|
)
|
||||||
if self.server_args.load_format == "gguf":
|
if self.server_args.load_format == "gguf":
|
||||||
monkey_patch_vllm_gguf_config()
|
monkey_patch_vllm_gguf_config()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
# ruff: noqa: SIM117
|
# ruff: noqa: SIM117
|
||||||
import collections
|
import collections
|
||||||
|
import concurrent
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import glob
|
import glob
|
||||||
@@ -11,14 +12,17 @@ import math
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||||
|
|
||||||
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
|
|||||||
set_default_torch_dtype,
|
set_default_torch_dtype,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
|
_BAR_FORMAT,
|
||||||
download_safetensors_index_file_from_hf,
|
download_safetensors_index_file_from_hf,
|
||||||
download_weights_from_hf,
|
download_weights_from_hf,
|
||||||
filter_duplicate_safetensors_files,
|
filter_duplicate_safetensors_files,
|
||||||
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
get_quant_config,
|
get_quant_config,
|
||||||
gguf_quant_weights_iterator,
|
gguf_quant_weights_iterator,
|
||||||
initialize_dummy_weights,
|
initialize_dummy_weights,
|
||||||
|
multi_thread_pt_weights_iterator,
|
||||||
|
multi_thread_safetensors_weights_iterator,
|
||||||
np_cache_weights_iterator,
|
np_cache_weights_iterator,
|
||||||
pt_weights_iterator,
|
pt_weights_iterator,
|
||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
|
|||||||
class DefaultModelLoader(BaseModelLoader):
|
class DefaultModelLoader(BaseModelLoader):
|
||||||
"""Model loader that can load different file types from disk."""
|
"""Model loader that can load different file types from disk."""
|
||||||
|
|
||||||
|
# Default number of threads used when multi-threaded weight loading is enabled.
|
||||||
|
DEFAULT_NUM_THREADS = 8
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Source:
|
class Source:
|
||||||
"""A source for weights."""
|
"""A source for weights."""
|
||||||
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
def __init__(self, load_config: LoadConfig):
|
def __init__(self, load_config: LoadConfig):
|
||||||
super().__init__(load_config)
|
super().__init__(load_config)
|
||||||
if load_config.model_loader_extra_config:
|
extra_config = load_config.model_loader_extra_config
|
||||||
|
allowed_keys = {"enable_multithread_load", "num_threads"}
|
||||||
|
unexpected_keys = set(extra_config.keys()) - allowed_keys
|
||||||
|
|
||||||
|
if unexpected_keys:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model loader extra config is not supported for "
|
f"Unexpected extra config keys for load format "
|
||||||
f"load format {load_config.load_format}"
|
f"{load_config.load_format}: "
|
||||||
|
f"{unexpected_keys}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _maybe_download_from_modelscope(
|
def _maybe_download_from_modelscope(
|
||||||
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self, source: "Source"
|
self, source: "Source"
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
"""Get an iterator for the model weights based on the load format."""
|
"""Get an iterator for the model weights based on the load format."""
|
||||||
|
extra_config = self.load_config.model_loader_extra_config
|
||||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||||
source.model_or_path, source.revision, source.fall_back_to_pt
|
source.model_or_path, source.revision, source.fall_back_to_pt
|
||||||
)
|
)
|
||||||
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
weight_loader_disable_mmap = global_server_args_dict.get(
|
weight_loader_disable_mmap = global_server_args_dict.get(
|
||||||
"weight_loader_disable_mmap"
|
"weight_loader_disable_mmap"
|
||||||
)
|
)
|
||||||
weights_iterator = safetensors_weights_iterator(
|
|
||||||
hf_weights_files, disable_mmap=weight_loader_disable_mmap
|
if extra_config.get("enable_multithread_load"):
|
||||||
)
|
weights_iterator = multi_thread_safetensors_weights_iterator(
|
||||||
|
hf_weights_files,
|
||||||
|
max_workers=extra_config.get(
|
||||||
|
"num_threads", self.DEFAULT_NUM_THREADS
|
||||||
|
),
|
||||||
|
disable_mmap=weight_loader_disable_mmap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weights_iterator = safetensors_weights_iterator(
|
||||||
|
hf_weights_files, disable_mmap=weight_loader_disable_mmap
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
weights_iterator = pt_weights_iterator(hf_weights_files)
|
if extra_config.get("enable_multithread_load"):
|
||||||
|
weights_iterator = multi_thread_pt_weights_iterator(
|
||||||
|
hf_weights_files,
|
||||||
|
max_workers=extra_config.get(
|
||||||
|
"num_threads", self.DEFAULT_NUM_THREADS
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weights_iterator = pt_weights_iterator(hf_weights_files)
|
||||||
|
|
||||||
# Apply the prefix.
|
# Apply the prefix.
|
||||||
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
|
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
|
||||||
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self.load_config,
|
self.load_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.load_weights_and_postprocess(
|
self.load_weights_and_postprocess(
|
||||||
model, self._get_all_weights(model_config, model), target_device
|
model, self._get_all_weights(model_config, model), target_device
|
||||||
)
|
)
|
||||||
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
|
||||||
|
|
||||||
"""Utilities for downloading and initializing model weights."""
|
"""Utilities for downloading and initializing model weights."""
|
||||||
|
import concurrent.futures
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import glob
|
import glob
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
import tempfile
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
|
|||||||
yield name, param
|
yield name, param
|
||||||
|
|
||||||
|
|
||||||
|
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in safetensors files, loading files in parallel.

    Every file in ``hf_weights_files`` is submitted to a thread pool, and the
    tensors of a file are yielded as soon as that file has finished loading.
    Files are therefore yielded in completion order, which is not necessarily
    the order of ``hf_weights_files``.

    Args:
        hf_weights_files: Paths of the safetensors files to load.
        is_all_weights_sharded: Only forwarded to
            ``safetensors_encrypted_weights_iterator`` when ``decryption_key``
            is set. The multi-threaded path always loads whole files and does
            not consult this flag.
        decryption_key: If truthy, a warning is logged and loading is
            delegated to ``safetensors_encrypted_weights_iterator`` instead of
            the thread pool.
        max_workers: Number of worker threads in the pool.
        disable_mmap: If true, each file is read fully with ``open().read()``
            and parsed with ``safetensors.torch.load``; otherwise
            ``safetensors.torch.load_file(..., device="cpu")`` is used.

    Yields:
        ``(name, tensor)`` pairs for every entry of every loaded file.

    Raises:
        Any exception raised while loading a file is re-raised here when that
        file's result is consumed; loads that have not started yet are
        cancelled first.
    """
    if decryption_key:
        logger.warning(
            "Multi-Thread loading is not working for encrypted safetensor weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Show the progress bar only on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.torch.load(f.read())
        return safetensors.torch.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Track futures in a set we shrink as results are consumed. Keeping a
        # full list would pin every loaded file in memory until the generator
        # is exhausted, because each Future holds on to its result.
        outstanding = {
            executor.submit(_load_file, st_file) for st_file in hf_weights_files
        }

        # `disable=` already turns tqdm into a plain pass-through iterator, so
        # no separate non-tqdm branch is needed.
        futures_iter = tqdm(
            concurrent.futures.as_completed(outstanding),
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )

        try:
            for future in futures_iter:
                outstanding.discard(future)
                state_dict = future.result()
                del future
                yield from state_dict.items()
                # Release this file's tensors before waiting for the next one.
                del state_dict
        finally:
            # On failure or early exit, do not let the executor's shutdown run
            # every queued load: cancel whatever has not started. cancel() is
            # a no-op for futures that are running or already finished.
            for pending in outstanding:
                pending.cancel()
|
||||||
|
|
||||||
|
|
||||||
def pt_weights_iterator(
|
def pt_weights_iterator(
|
||||||
hf_weights_files: List[str],
|
hf_weights_files: List[str],
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
@@ -471,6 +527,39 @@ def pt_weights_iterator(
|
|||||||
del state
|
del state
|
||||||
|
|
||||||
|
|
||||||
|
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in bin/pt files, loading files in parallel.

    Every file in ``hf_weights_files`` is submitted to a thread pool and
    loaded with ``torch.load(..., map_location="cpu", weights_only=True)``.
    The entries of a file are yielded as soon as that file has finished
    loading, so files are yielded in completion order, which is not
    necessarily the order of ``hf_weights_files``.

    Args:
        hf_weights_files: Paths of the checkpoint files to load.
        max_workers: Number of worker threads in the pool.

    Yields:
        The ``(key, value)`` items of every loaded checkpoint.

    Raises:
        Any exception raised while loading a file is re-raised here when that
        file's result is consumed; loads that have not started yet are
        cancelled first.
    """
    # Show the progress bar only on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Track futures in a set we shrink as results are consumed. Keeping a
        # full list would pin every loaded checkpoint in memory until the
        # generator is exhausted, because each Future holds on to its result.
        outstanding = {
            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
        }

        # `disable=` already turns tqdm into a plain pass-through iterator, so
        # no separate non-tqdm branch is needed.
        futures_iter = tqdm(
            concurrent.futures.as_completed(outstanding),
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )

        try:
            for future in futures_iter:
                outstanding.discard(future)
                state = future.result()
                del future
                yield from state.items()
                # Release the consumed checkpoint, as the sequential
                # pt_weights_iterator does with its own `del state`.
                del state
        finally:
            # On failure or early exit, do not let the executor's shutdown run
            # every queued load: cancel whatever has not started. cancel() is
            # a no-op for futures that are running or already finished.
            for pending in outstanding:
                pending.cancel()
|
||||||
|
|
||||||
|
|
||||||
def get_gguf_extra_tensor_names(
|
def get_gguf_extra_tensor_names(
|
||||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class ServerArgs:
|
|||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: str = "auto"
|
||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
|
model_loader_extra_config: str = "{}"
|
||||||
trust_remote_code: bool = False
|
trust_remote_code: bool = False
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
kv_cache_dtype: str = "auto"
|
kv_cache_dtype: str = "auto"
|
||||||
@@ -632,6 +633,13 @@ class ServerArgs:
|
|||||||
"layer before loading another to make the peak memory envelope "
|
"layer before loading another to make the peak memory envelope "
|
||||||
"smaller.",
|
"smaller.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-loader-extra-config",
|
||||||
|
type=str,
|
||||||
|
help="Extra config for model loader. "
|
||||||
|
"This will be passed to the model loader corresponding to the chosen load_format.",
|
||||||
|
default=ServerArgs.model_loader_extra_config,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user