ci: make find_local_hf_snapshot_dir more robust (#11248)
This commit is contained in:
@@ -8,6 +8,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -283,7 +284,24 @@ def find_local_hf_snapshot_dir(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to find local snapshot in default HF cache: %s", e)
|
logger.warning("Failed to find local snapshot in default HF cache: %s", e)
|
||||||
|
|
||||||
# If local snapshot exists, validate it contains at least one weight file
|
# if any incomplete file exists, force re-download by returning None
|
||||||
|
if found_local_snapshot_dir:
|
||||||
|
repo_folder = os.path.abspath(
|
||||||
|
os.path.join(found_local_snapshot_dir, "..", "..")
|
||||||
|
)
|
||||||
|
blobs_dir = os.path.join(repo_folder, "blobs")
|
||||||
|
if os.path.isdir(blobs_dir) and glob.glob(
|
||||||
|
os.path.join(blobs_dir, "*.incomplete")
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Found .incomplete files in %s for %s. "
|
||||||
|
"Considering local snapshot incomplete.",
|
||||||
|
blobs_dir,
|
||||||
|
model_name_or_path,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# if local snapshot exists, validate it contains at least one weight file
|
||||||
# matching allow_patterns before skipping download.
|
# matching allow_patterns before skipping download.
|
||||||
if found_local_snapshot_dir is None:
|
if found_local_snapshot_dir is None:
|
||||||
return None
|
return None
|
||||||
@@ -291,9 +309,12 @@ def find_local_hf_snapshot_dir(
|
|||||||
local_weight_files: List[str] = []
|
local_weight_files: List[str] = []
|
||||||
try:
|
try:
|
||||||
for pattern in allow_patterns:
|
for pattern in allow_patterns:
|
||||||
local_weight_files.extend(
|
matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
|
||||||
glob.glob(os.path.join(found_local_snapshot_dir, pattern))
|
for f in matched_files:
|
||||||
)
|
# os.path.exists returns False for broken symlinks.
|
||||||
|
if not os.path.exists(f):
|
||||||
|
continue
|
||||||
|
local_weight_files.append(f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to scan local snapshot %s with patterns %s: %s",
|
"Failed to scan local snapshot %s with patterns %s: %s",
|
||||||
@@ -303,6 +324,46 @@ def find_local_hf_snapshot_dir(
|
|||||||
)
|
)
|
||||||
local_weight_files = []
|
local_weight_files = []
|
||||||
|
|
||||||
|
# After we have a list of valid files, check for sharded model completeness.
|
||||||
|
# Check if all safetensors with name model-{i}-of-{n}.safetensors exists
|
||||||
|
checked_sharded_model = False
|
||||||
|
for f in local_weight_files:
|
||||||
|
if checked_sharded_model:
|
||||||
|
break
|
||||||
|
base_name = os.path.basename(f)
|
||||||
|
# Regex for files like model-00001-of-00009.safetensors
|
||||||
|
match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
|
||||||
|
if match:
|
||||||
|
prefix = match.group(1)
|
||||||
|
shard_id_str = match.group(2)
|
||||||
|
total_shards_str = match.group(3)
|
||||||
|
suffix = match.group(4)
|
||||||
|
total_shards = int(total_shards_str)
|
||||||
|
|
||||||
|
# Check if all shards are present
|
||||||
|
missing_shards = []
|
||||||
|
for i in range(1, total_shards + 1):
|
||||||
|
# Reconstruct shard name, preserving padding of original shard id
|
||||||
|
shard_name = (
|
||||||
|
f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
|
||||||
|
)
|
||||||
|
expected_path = os.path.join(found_local_snapshot_dir, shard_name)
|
||||||
|
# os.path.exists returns False for broken symlinks, which is desired.
|
||||||
|
if not os.path.exists(expected_path):
|
||||||
|
missing_shards.append(shard_name)
|
||||||
|
|
||||||
|
if missing_shards:
|
||||||
|
logger.info(
|
||||||
|
"Found incomplete sharded model %s. Missing shards: %s. "
|
||||||
|
"Will attempt download.",
|
||||||
|
model_name_or_path,
|
||||||
|
missing_shards,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If we found and verified one set of shards, we are done.
|
||||||
|
checked_sharded_model = True
|
||||||
|
|
||||||
if len(local_weight_files) > 0:
|
if len(local_weight_files) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Found local HF snapshot for %s at %s; skipping download.",
|
"Found local HF snapshot for %s at %s; skipping download.",
|
||||||
|
|||||||
Reference in New Issue
Block a user