From 9f1787fa60db52e6f5a35eba0f4aee161214ef34 Mon Sep 17 00:00:00 2001 From: xianzhiT Date: Wed, 25 Jun 2025 01:39:10 +0800 Subject: [PATCH] Support multi-thread model weight loading (#7277) --- .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/model_loader/loader.py | 55 +++++++++--- .../sglang/srt/model_loader/weight_utils.py | 89 +++++++++++++++++++ python/sglang/srt/server_args.py | 8 ++ 4 files changed, 143 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2743fe51e..bd6a027d5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -547,6 +547,7 @@ class ModelRunner: self.load_config = LoadConfig( load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, + model_loader_extra_config=self.server_args.model_loader_extra_config, ) if self.server_args.load_format == "gguf": monkey_patch_vllm_gguf_config() diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 0aebe2f9f..29c82b084 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -2,6 +2,7 @@ # ruff: noqa: SIM117 import collections +import concurrent import dataclasses import fnmatch import glob @@ -11,14 +12,17 @@ import math import os import time from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast import huggingface_hub import numpy as np +import safetensors.torch import torch from huggingface_hub import HfApi, hf_hub_download from torch import nn +from tqdm.auto import tqdm from transformers import AutoModelForCausalLM from transformers.utils import SAFE_WEIGHTS_INDEX_NAME @@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import ( 
set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
+    _BAR_FORMAT,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
     get_quant_config,
     gguf_quant_weights_iterator,
     initialize_dummy_weights,
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator,
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # Default number of threads to use when multi-thread weight loading is enabled.
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
             raise ValueError(
-                f"Model loader extra config is not supported for "
-                f"load format {load_config.load_format}"
+                f"Unexpected extra config keys for load format "
+                f"{load_config.load_format}: "
+                f"{unexpected_keys}"
             )
 
     def _maybe_download_from_modelscope(
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt
         )
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
             weight_loader_disable_mmap = global_server_args_dict.get(
"weight_loader_disable_mmap" ) - weights_iterator = safetensors_weights_iterator( - hf_weights_files, disable_mmap=weight_loader_disable_mmap - ) + + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_safetensors_weights_iterator( + hf_weights_files, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS + ), + disable_mmap=weight_loader_disable_mmap, + ) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, disable_mmap=weight_loader_disable_mmap + ) + else: - weights_iterator = pt_weights_iterator(hf_weights_files) + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_pt_weights_iterator( + hf_weights_files, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS + ), + ) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) # Apply the prefix. return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) @@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader): self.load_config, ) - self.load_weights_and_postprocess( - model, self._get_all_weights(model_config, model), target_device - ) + self.load_weights_and_postprocess( + model, self._get_all_weights(model_config, model), target_device + ) return model.eval() diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 722f8e1d4..33bc4e152 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -1,12 +1,14 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py """Utilities for downloading and initializing model weights.""" +import concurrent.futures import fnmatch import glob import hashlib import json import logging import os +import queue import tempfile from collections import defaultdict from typing import ( @@ -453,6 +455,60 @@ def safetensors_weights_iterator( yield 
name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+    max_workers: int = 4,
+    disable_mmap: bool = False,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files using multiple threads.
+
+    If is_all_weights_sharded is True, it uses a more optimized read by reading an
+    entire file instead of reading each tensor one by one.
+    """
+    if decryption_key:
+        logger.warning(
+            "Multi-Thread loading is not working for encrypted safetensor weights."
+        )
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
+
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -471,6 +527,39 @@ def pt_weights_iterator(
         del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: List[str],
+    max_workers: int = 4,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the 
model bin/pt files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + + def _load_file(bin_file: str): + return torch.load(bin_file, map_location="cpu", weights_only=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, bin_file) for bin_file in hf_weights_files + ] + + if enable_tqdm: + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ) + else: + futures_iter = concurrent.futures.as_completed(futures) + + for future in futures_iter: + state = future.result() + yield from state.items() + + def get_gguf_extra_tensor_names( gguf_file: str, gguf_to_hf_name_map: Dict[str, str] ) -> List[str]: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e5b1c1809..2014f7e95 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -47,6 +47,7 @@ class ServerArgs: tokenizer_mode: str = "auto" skip_tokenizer_init: bool = False load_format: str = "auto" + model_loader_extra_config: str = "{}" trust_remote_code: bool = False dtype: str = "auto" kv_cache_dtype: str = "auto" @@ -632,6 +633,13 @@ class ServerArgs: "layer before loading another to make the peak memory envelope " "smaller.", ) + parser.add_argument( + "--model-loader-extra-config", + type=str, + help="Extra config for model loader. " + "This will be passed to the model loader corresponding to the chosen load_format.", + default=ServerArgs.model_loader_extra_config, + ) parser.add_argument( "--trust-remote-code", action="store_true",