fix: check if weights are already local before downloading (#11015)
This commit is contained in:
@@ -8,7 +8,6 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
@@ -38,7 +37,8 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
||||
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
||||
from sglang.srt.utils import print_warning_once
|
||||
from sglang.srt.utils import find_local_repo_dir, print_warning_once
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -236,6 +236,92 @@ def get_quant_config(
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def find_local_hf_snapshot_dir(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """Return a local HF snapshot dir that already holds weights, if any.

    The lookup is only performed in CI (``is_in_ci()``) and only when
    ``model_name_or_path`` is not itself an existing directory; otherwise
    ``None`` is returned immediately so the caller proceeds with its normal
    download path.

    Args:
        model_name_or_path: Hub repo id (e.g. ``"org/model"``).
        cache_dir: Optional custom HF cache root, searched first.
        allow_patterns: Glob patterns (relative to the snapshot directory)
            identifying weight files.
        revision: Optional commit hash or ref name (branch/tag). When
            omitted, the ``main`` ref of the custom cache is used.

    Returns:
        The snapshot directory if it exists and contains at least one
        regular file matching ``allow_patterns``; otherwise ``None``.
    """
    if not is_in_ci() or os.path.isdir(model_name_or_path):
        return None

    found_local_snapshot_dir = None

    # Check custom cache_dir (if provided)
    if cache_dir:
        # Deliberately best-effort: any failure here only means we fall back
        # to the default cache / a regular download.
        try:
            found_local_snapshot_dir = _find_snapshot_in_cache_dir(
                cache_dir, model_name_or_path, revision
            )
        except Exception as e:
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # Check default HF cache as well
    if not found_local_snapshot_dir:
        try:
            rev_dir = find_local_repo_dir(model_name_or_path, revision)
            if rev_dir and os.path.isdir(rev_dir):
                found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    if found_local_snapshot_dir is None:
        return None

    # Validate that the snapshot contains at least one real weight file
    # matching allow_patterns before skipping the download.
    local_weight_files: List[str] = []
    try:
        for pattern in allow_patterns:
            # glob also returns dangling symlinks (e.g. left behind by an
            # interrupted download whose blob is missing) and directories;
            # os.path.isfile follows symlinks, so only usable files count.
            local_weight_files.extend(
                path
                for path in glob.glob(os.path.join(found_local_snapshot_dir, pattern))
                if os.path.isfile(path)
            )
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            found_local_snapshot_dir,
            allow_patterns,
            e,
        )
        local_weight_files = []

    if local_weight_files:
        logger.info(
            "Found local HF snapshot for %s at %s; skipping download.",
            model_name_or_path,
            found_local_snapshot_dir,
        )
        return found_local_snapshot_dir

    logger.info(
        "Local HF snapshot at %s has no files matching %s; will attempt download.",
        found_local_snapshot_dir,
        allow_patterns,
    )
    return None


def _find_snapshot_in_cache_dir(
    cache_dir: str,
    model_name_or_path: str,
    revision: Optional[str],
) -> Optional[str]:
    """Locate ``snapshots/<commit>`` for a repo inside a custom HF cache root.

    ``revision`` is first tried verbatim as a snapshot directory name (commit
    hash). If that does not exist, it is treated as a ref name and resolved
    through ``refs/<revision>``; with no revision, ``refs/main`` is used.
    Returns ``None`` when nothing matching exists on disk. May raise
    ``OSError`` if the ref file cannot be read.
    """
    repo_folder = os.path.join(
        cache_dir,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *model_name_or_path.split("/")]
        ),
    )
    snapshots_dir = os.path.join(repo_folder, "snapshots")

    # An explicit revision may already be a commit hash.
    if revision:
        rev_dir = os.path.join(snapshots_dir, revision)
        if os.path.isdir(rev_dir):
            return rev_dir

    # Otherwise resolve the ref (branch/tag name, default "main") to the
    # commit hash recorded in the cache's refs folder.
    ref_file = os.path.join(repo_folder, "refs", revision or "main")
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file, encoding="utf-8") as f:
        commit_hash = f.read().strip()
    if not commit_hash:
        return None

    rev_dir = os.path.join(snapshots_dir, commit_hash)
    return rev_dir if os.path.isdir(rev_dir) else None
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
@@ -260,6 +346,13 @@ def download_weights_from_hf(
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
|
||||
path = find_local_hf_snapshot_dir(
|
||||
model_name_or_path, cache_dir, allow_patterns, revision
|
||||
)
|
||||
if path is not None:
|
||||
return path
|
||||
|
||||
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
||||
# Before we download we look at what is available:
|
||||
fs = HfFileSystem()
|
||||
|
||||
Reference in New Issue
Block a user