fix: check if weights are already local before downloading (#11015)
This commit is contained in:
@@ -8,7 +8,6 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import queue
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -38,7 +37,8 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
||||||
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
||||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
||||||
from sglang.srt.utils import print_warning_once
|
from sglang.srt.utils import find_local_repo_dir, print_warning_once
|
||||||
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -236,6 +236,92 @@ def get_quant_config(
|
|||||||
return quant_cls.from_config(config)
|
return quant_cls.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def find_local_hf_snapshot_dir(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str],
|
||||||
|
allow_patterns: List[str],
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""If the weights are already local, skip downloading and returns the path
|
||||||
|
|
||||||
|
Only applied in ci
|
||||||
|
"""
|
||||||
|
if not is_in_ci() or os.path.isdir(model_name_or_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
found_local_snapshot_dir = None
|
||||||
|
|
||||||
|
# Check custom cache_dir (if provided)
|
||||||
|
if cache_dir:
|
||||||
|
try:
|
||||||
|
repo_folder = os.path.join(
|
||||||
|
cache_dir,
|
||||||
|
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
|
||||||
|
["models", *model_name_or_path.split("/")]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
rev_to_use = revision
|
||||||
|
if not rev_to_use:
|
||||||
|
ref_main = os.path.join(repo_folder, "refs", "main")
|
||||||
|
if os.path.isfile(ref_main):
|
||||||
|
with open(ref_main) as f:
|
||||||
|
rev_to_use = f.read().strip()
|
||||||
|
if rev_to_use:
|
||||||
|
rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
|
||||||
|
if os.path.isdir(rev_dir):
|
||||||
|
found_local_snapshot_dir = rev_dir
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to find local snapshot in custom cache_dir %s: %s",
|
||||||
|
cache_dir,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check default HF cache as well
|
||||||
|
if not found_local_snapshot_dir:
|
||||||
|
try:
|
||||||
|
rev_dir = find_local_repo_dir(model_name_or_path, revision)
|
||||||
|
if rev_dir and os.path.isdir(rev_dir):
|
||||||
|
found_local_snapshot_dir = rev_dir
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to find local snapshot in default HF cache: %s", e)
|
||||||
|
|
||||||
|
# If local snapshot exists, validate it contains at least one weight file
|
||||||
|
# matching allow_patterns before skipping download.
|
||||||
|
if found_local_snapshot_dir is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
local_weight_files: List[str] = []
|
||||||
|
try:
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
local_weight_files.extend(
|
||||||
|
glob.glob(os.path.join(found_local_snapshot_dir, pattern))
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to scan local snapshot %s with patterns %s: %s",
|
||||||
|
found_local_snapshot_dir,
|
||||||
|
allow_patterns,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
local_weight_files = []
|
||||||
|
|
||||||
|
if len(local_weight_files) > 0:
|
||||||
|
logger.info(
|
||||||
|
"Found local HF snapshot for %s at %s; skipping download.",
|
||||||
|
model_name_or_path,
|
||||||
|
found_local_snapshot_dir,
|
||||||
|
)
|
||||||
|
return found_local_snapshot_dir
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Local HF snapshot at %s has no files matching %s; will attempt download.",
|
||||||
|
found_local_snapshot_dir,
|
||||||
|
allow_patterns,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def download_weights_from_hf(
|
def download_weights_from_hf(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str],
|
cache_dir: Optional[str],
|
||||||
@@ -260,6 +346,13 @@ def download_weights_from_hf(
|
|||||||
Returns:
|
Returns:
|
||||||
str: The path to the downloaded model weights.
|
str: The path to the downloaded model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
path = find_local_hf_snapshot_dir(
|
||||||
|
model_name_or_path, cache_dir, allow_patterns, revision
|
||||||
|
)
|
||||||
|
if path is not None:
|
||||||
|
return path
|
||||||
|
|
||||||
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
||||||
# Before we download we look at that is available:
|
# Before we download we look at that is available:
|
||||||
fs = HfFileSystem()
|
fs = HfFileSystem()
|
||||||
|
|||||||
Reference in New Issue
Block a user