From 97d966a7f8d87a3238a5b1197f51db3ec53f4771 Mon Sep 17 00:00:00 2001
From: Mick
Date: Mon, 6 Oct 2025 10:50:11 +0800
Subject: [PATCH] ci: make find_local_hf_snapshot_dir more robust (#11248)

---
 .../sglang/srt/model_loader/weight_utils.py | 69 +++++++++++++++++--
 1 file changed, 65 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index 44297d687..77bc0103f 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -8,6 +8,7 @@ import hashlib
 import json
 import logging
 import os
+import re
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -283,7 +284,24 @@ def find_local_hf_snapshot_dir(
     except Exception as e:
         logger.warning("Failed to find local snapshot in default HF cache: %s", e)
 
-    # If local snapshot exists, validate it contains at least one weight file
+    # if any incomplete file exists, force re-download by returning None
+    if found_local_snapshot_dir:
+        repo_folder = os.path.abspath(
+            os.path.join(found_local_snapshot_dir, "..", "..")
+        )
+        blobs_dir = os.path.join(repo_folder, "blobs")
+        if os.path.isdir(blobs_dir) and glob.glob(
+            os.path.join(blobs_dir, "*.incomplete")
+        ):
+            logger.info(
+                "Found .incomplete files in %s for %s. "
+                "Considering local snapshot incomplete.",
+                blobs_dir,
+                model_name_or_path,
+            )
+            return None
+
+    # if local snapshot exists, validate it contains at least one weight file
     # matching allow_patterns before skipping download.
     if found_local_snapshot_dir is None:
         return None
@@ -291,9 +309,12 @@ def find_local_hf_snapshot_dir(
     local_weight_files: List[str] = []
     try:
         for pattern in allow_patterns:
-            local_weight_files.extend(
-                glob.glob(os.path.join(found_local_snapshot_dir, pattern))
-            )
+            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
+            for f in matched_files:
+                # os.path.exists returns False for broken symlinks.
+                if not os.path.exists(f):
+                    continue
+                local_weight_files.append(f)
     except Exception as e:
         logger.warning(
             "Failed to scan local snapshot %s with patterns %s: %s",
@@ -303,6 +324,46 @@ def find_local_hf_snapshot_dir(
         )
         local_weight_files = []
 
+    # After we have a list of valid files, check for sharded model completeness.
+    # Check if all safetensors with name model-{i}-of-{n}.safetensors exists
+    checked_sharded_model = False
+    for f in local_weight_files:
+        if checked_sharded_model:
+            break
+        base_name = os.path.basename(f)
+        # Regex for files like model-00001-of-00009.safetensors
+        match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
+        if match:
+            prefix = match.group(1)
+            shard_id_str = match.group(2)
+            total_shards_str = match.group(3)
+            suffix = match.group(4)
+            total_shards = int(total_shards_str)
+
+            # Check if all shards are present
+            missing_shards = []
+            for i in range(1, total_shards + 1):
+                # Reconstruct shard name, preserving padding of original shard id
+                shard_name = (
+                    f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
+                )
+                expected_path = os.path.join(found_local_snapshot_dir, shard_name)
+                # os.path.exists returns False for broken symlinks, which is desired.
+                if not os.path.exists(expected_path):
+                    missing_shards.append(shard_name)
+
+            if missing_shards:
+                logger.info(
+                    "Found incomplete sharded model %s. Missing shards: %s. "
+                    "Will attempt download.",
+                    model_name_or_path,
+                    missing_shards,
+                )
+                return None
+
+            # If we found and verified one set of shards, we are done.
+            checked_sharded_model = True
+
     if len(local_weight_files) > 0:
         logger.info(
             "Found local HF snapshot for %s at %s; skipping download.",