ci: make find_local_hf_snapshot_dir more robust (#11248)
This commit is contained in:
@@ -8,6 +8,7 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
@@ -283,7 +284,24 @@ def find_local_hf_snapshot_dir(
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find local snapshot in default HF cache: %s", e)
|
||||
|
||||
# If a local snapshot exists, check the repo's blobs directory for *.incomplete files;
|
||||
# if any incomplete file exists, force re-download by returning None
|
||||
if found_local_snapshot_dir:
|
||||
repo_folder = os.path.abspath(
|
||||
os.path.join(found_local_snapshot_dir, "..", "..")
|
||||
)
|
||||
blobs_dir = os.path.join(repo_folder, "blobs")
|
||||
if os.path.isdir(blobs_dir) and glob.glob(
|
||||
os.path.join(blobs_dir, "*.incomplete")
|
||||
):
|
||||
logger.info(
|
||||
"Found .incomplete files in %s for %s. "
|
||||
"Considering local snapshot incomplete.",
|
||||
blobs_dir,
|
||||
model_name_or_path,
|
||||
)
|
||||
return None
|
||||
|
||||
# if local snapshot exists, validate it contains at least one weight file
|
||||
# matching allow_patterns before skipping download.
|
||||
if found_local_snapshot_dir is None:
|
||||
return None
|
||||
@@ -291,9 +309,12 @@ def find_local_hf_snapshot_dir(
|
||||
local_weight_files: List[str] = []
|
||||
try:
|
||||
for pattern in allow_patterns:
|
||||
local_weight_files.extend(
|
||||
glob.glob(os.path.join(found_local_snapshot_dir, pattern))
|
||||
)
|
||||
matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
|
||||
for f in matched_files:
|
||||
# os.path.exists returns False for broken symlinks.
|
||||
if not os.path.exists(f):
|
||||
continue
|
||||
local_weight_files.append(f)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to scan local snapshot %s with patterns %s: %s",
|
||||
@@ -303,6 +324,46 @@ def find_local_hf_snapshot_dir(
|
||||
)
|
||||
local_weight_files = []
|
||||
|
||||
# After we have a list of valid files, check for sharded model completeness.
|
||||
# Check that every shard file named like model-{i}-of-{n}.safetensors exists.
|
||||
checked_sharded_model = False
|
||||
for f in local_weight_files:
|
||||
if checked_sharded_model:
|
||||
break
|
||||
base_name = os.path.basename(f)
|
||||
# Regex for files like model-00001-of-00009.safetensors
|
||||
match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
|
||||
if match:
|
||||
prefix = match.group(1)
|
||||
shard_id_str = match.group(2)
|
||||
total_shards_str = match.group(3)
|
||||
suffix = match.group(4)
|
||||
total_shards = int(total_shards_str)
|
||||
|
||||
# Check if all shards are present
|
||||
missing_shards = []
|
||||
for i in range(1, total_shards + 1):
|
||||
# Reconstruct shard name, preserving padding of original shard id
|
||||
shard_name = (
|
||||
f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
|
||||
)
|
||||
expected_path = os.path.join(found_local_snapshot_dir, shard_name)
|
||||
# os.path.exists returns False for broken symlinks, which is desired.
|
||||
if not os.path.exists(expected_path):
|
||||
missing_shards.append(shard_name)
|
||||
|
||||
if missing_shards:
|
||||
logger.info(
|
||||
"Found incomplete sharded model %s. Missing shards: %s. "
|
||||
"Will attempt download.",
|
||||
model_name_or_path,
|
||||
missing_shards,
|
||||
)
|
||||
return None
|
||||
|
||||
# If we found and verified one set of shards, we are done.
|
||||
checked_sharded_model = True
|
||||
|
||||
if len(local_weight_files) > 0:
|
||||
logger.info(
|
||||
"Found local HF snapshot for %s at %s; skipping download.",
|
||||
|
||||
Reference in New Issue
Block a user