[Misc] Fix issues reported by torchfix (#4837)
This commit is contained in:
@@ -92,7 +92,7 @@ def convert_bin_to_safetensor_file(
|
|||||||
pt_filename: str,
|
pt_filename: str,
|
||||||
sf_filename: str,
|
sf_filename: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
loaded = torch.load(pt_filename, map_location="cpu")
|
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
|
||||||
if "state_dict" in loaded:
|
if "state_dict" in loaded:
|
||||||
loaded = loaded["state_dict"]
|
loaded = loaded["state_dict"]
|
||||||
shared = _shared_pointers(loaded)
|
shared = _shared_pointers(loaded)
|
||||||
@@ -380,7 +380,7 @@ def np_cache_weights_iterator(
|
|||||||
disable=not enable_tqdm,
|
disable=not enable_tqdm,
|
||||||
bar_format=_BAR_FORMAT,
|
bar_format=_BAR_FORMAT,
|
||||||
):
|
):
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu", weights_only=True)
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
param_path = os.path.join(np_folder, name)
|
param_path = os.path.join(np_folder, name)
|
||||||
with open(param_path, "wb") as f:
|
with open(param_path, "wb") as f:
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ def resample_patch_embed(
|
|||||||
try:
|
try:
|
||||||
from torch import vmap
|
from torch import vmap
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from functorch import vmap
|
from torch.func import vmap
|
||||||
|
|
||||||
assert len(patch_embed.shape) == 4, "Four dimensions expected"
|
assert len(patch_embed.shape) == 4, "Four dimensions expected"
|
||||||
assert len(new_size) == 2, "New shape should only be hw"
|
assert len(new_size) == 2, "New shape should only be hw"
|
||||||
@@ -1084,7 +1084,7 @@ def create_siglip_vit(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if ckpt_path:
|
if ckpt_path:
|
||||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||||
|
|
||||||
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||||
print(
|
print(
|
||||||
|
|||||||
@@ -586,5 +586,5 @@ def load_token_map(token_map_path: str) -> List[int]:
|
|||||||
ignore_patterns=["*.bin", "*.safetensors"],
|
ignore_patterns=["*.bin", "*.safetensors"],
|
||||||
)
|
)
|
||||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||||
hot_token_id = torch.load(token_map_path)
|
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||||
return torch.tensor(hot_token_id, dtype=torch.int32)
|
return torch.tensor(hot_token_id, dtype=torch.int32)
|
||||||
|
|||||||
Reference in New Issue
Block a user