44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
|
|
import json
|
||
|
|
from safetensors import safe_open
|
||
|
|
|
||
|
|
def generate_safetensors_index(model_path=".", safetensors_files=None):
    """Generate ``model.safetensors.index.json`` from existing safetensors shards.

    Reads ``pytorch_model.bin.index.json`` from *model_path* and copies its
    ``metadata`` section. It then opens each safetensors shard and records
    every tensor name it contains in ``weight_map`` (tensor name -> shard file
    name). The result is written to ``model.safetensors.index.json`` inside
    *model_path*.

    Args:
        model_path: Directory holding the model files. All input and output
            paths are resolved relative to this directory.
        safetensors_files: Optional explicit list of shard file names,
            relative to *model_path*. When ``None`` (the default), names are
            derived from the shard files listed in the bin index's
            ``weight_map``, with their extension replaced by ``.safetensors``.
            If the bin index lists no shards, the historical four-shard
            ``pytorch_model-0000X-of-00004.safetensors`` names are used.

    Returns:
        The index dict that was written to disk.

    Raises:
        FileNotFoundError: If ``pytorch_model.bin.index.json`` is missing
            from *model_path*.
        json.JSONDecodeError: If the bin index is not valid JSON.

    Shards that cannot be opened are reported and skipped (best-effort). The
    index is still written, but it will lack their tensors.
    """
    import os

    # Load the existing bin index: its metadata is reused verbatim, and its
    # weight_map tells us which shard files the checkpoint consists of.
    bin_index_path = os.path.join(model_path, "pytorch_model.bin.index.json")
    with open(bin_index_path, "r", encoding="utf-8") as f:
        bin_index = json.load(f)

    safetensors_index = {
        "metadata": bin_index.get("metadata", {}),
        "weight_map": {},
    }

    if safetensors_files is None:
        # Mirror the .bin shard names, e.g. pytorch_model-00001-of-00004.bin
        # -> pytorch_model-00001-of-00004.safetensors. Sorting keeps the
        # processing order deterministic.
        bin_shards = bin_index.get("weight_map", {}).values()
        safetensors_files = sorted(
            {os.path.splitext(name)[0] + ".safetensors" for name in bin_shards}
        )
        if not safetensors_files:
            # Backward-compatible fallback: the original hard-coded shard list.
            safetensors_files = [
                "pytorch_model-00001-of-00004.safetensors",
                "pytorch_model-00002-of-00004.safetensors",
                "pytorch_model-00003-of-00004.safetensors",
                "pytorch_model-00004-of-00004.safetensors",
            ]

    failed = []
    for safetensor_file in safetensors_files:
        shard_path = os.path.join(model_path, safetensor_file)
        try:
            # Collect all names first so a failure mid-iteration cannot leave
            # a shard only partially recorded in the weight map.
            with safe_open(shard_path, framework="pt") as f:
                tensor_names = list(f.keys())
        except Exception as e:
            # Best-effort: report and keep going so every bad shard is listed.
            failed.append(safetensor_file)
            print(f"✗ Error processing {safetensor_file}: {e}")
            continue

        for tensor_name in tensor_names:
            # The index stores the bare file name, relative to the model dir.
            safetensors_index["weight_map"][tensor_name] = safetensor_file
        print(f"✓ Processed {safetensor_file}")

    if failed:
        # Tensors from these shards are absent from weight_map.
        print(
            f"⚠ Index is incomplete; {len(failed)} shard(s) could not be read: "
            f"{', '.join(failed)}"
        )

    index_path = os.path.join(model_path, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(safetensors_index, f, indent=2)

    print(
        f"✓ Generated {index_path} with "
        f"{len(safetensors_index['weight_map'])} tensors"
    )
    return safetensors_index
|
||
|
|
|
||
|
|
# Script entry point: build the safetensors index for the model directory below.
if __name__ == "__main__":
    # Edit this to point at a different model directory.
    target_dir = "./Finance-Llama-8B"
    generate_safetensors_index(target_dir)
|