def _resolve_snapshot_in_cache_dir(
    cache_dir: str,
    model_name_or_path: str,
    revision: Optional[str],
) -> Optional[str]:
    """Return the snapshot directory of a repo inside ``cache_dir``, if present.

    The lookup follows the Hugging Face cache layout
    ``<cache_dir>/models--<org>--<name>/{refs,snapshots}``.

    Args:
        cache_dir: Root of the custom Hugging Face cache.
        model_name_or_path: Repo id such as ``org/name``.
        revision: Commit hash, branch or tag. ``None`` means ``main``.

    Returns:
        The existing ``snapshots/<commit>`` directory, or ``None`` when the
        revision cannot be resolved or the directory does not exist.
    """
    repo_folder = os.path.join(
        cache_dir,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *model_name_or_path.split("/")]
        ),
    )

    # A branch/tag name is stored as a file under refs/ whose content is the
    # commit hash; a raw commit hash has no such file and is used verbatim.
    commit = revision
    ref_file = os.path.join(repo_folder, "refs", revision or "main")
    if os.path.isfile(ref_file):
        with open(ref_file, encoding="utf-8") as f:
            commit = f.read().strip() or revision
    if not commit:
        return None

    snapshot_dir = os.path.join(repo_folder, "snapshots", commit)
    return snapshot_dir if os.path.isdir(snapshot_dir) else None


def find_local_hf_snapshot_dir(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """Return a local HF snapshot directory if the weights are already cached.

    The check is only performed when ``is_in_ci()`` is true and
    ``model_name_or_path`` is not itself an existing directory; otherwise the
    function returns ``None`` so that the caller proceeds with its normal
    download path.

    Args:
        model_name_or_path: Hugging Face repo id (or a local path, which
            short-circuits to ``None``).
        cache_dir: Optional custom cache root; searched before the default
            cache.
        allow_patterns: Glob patterns (relative to the snapshot directory)
            identifying weight files.
        revision: Optional commit hash, branch or tag.

    Returns:
        The snapshot directory when it exists and contains at least one
        regular file matching ``allow_patterns``; ``None`` otherwise.
    """
    if not is_in_ci() or os.path.isdir(model_name_or_path):
        return None

    snapshot_dir: Optional[str] = None

    # Check custom cache_dir (if provided). Best effort: any failure here
    # only means we fall through to the default cache / download.
    if cache_dir:
        try:
            snapshot_dir = _resolve_snapshot_in_cache_dir(
                cache_dir, model_name_or_path, revision
            )
        except Exception as e:
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # Check default HF cache as well.
    if not snapshot_dir:
        try:
            rev_dir = find_local_repo_dir(model_name_or_path, revision)
            if rev_dir and os.path.isdir(rev_dir):
                snapshot_dir = rev_dir
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    if snapshot_dir is None:
        return None

    # Validate that the snapshot really holds a weight file before skipping
    # the download. glob also returns dangling symlinks (snapshot entries are
    # symlinks into blobs/), so require a resolvable regular file. The
    # directory part is escaped so special characters in it are not treated
    # as glob syntax.
    try:
        escaped_dir = glob.escape(snapshot_dir)
        has_weights = any(
            os.path.isfile(path)
            for pattern in allow_patterns
            for path in glob.glob(os.path.join(escaped_dir, pattern))
        )
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            snapshot_dir,
            allow_patterns,
            e,
        )
        has_weights = False

    if not has_weights:
        logger.info(
            "Local HF snapshot at %s has no files matching %s; will attempt download.",
            snapshot_dir,
            allow_patterns,
        )
        return None

    logger.info(
        "Found local HF snapshot for %s at %s; skipping download.",
        model_name_or_path,
        snapshot_dir,
    )
    return snapshot_dir
""" + + path = find_local_hf_snapshot_dir( + model_name_or_path, cache_dir, allow_patterns, revision + ) + if path is not None: + return path + if not huggingface_hub.constants.HF_HUB_OFFLINE: # Before we download we look at that is available: fs = HfFileSystem()