Files
enginex-mthreads-vllm/tests/model_executor/test_weight_utils.py

67 lines
2.0 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
import os
import tempfile
import huggingface_hub.constants
import pytest
from huggingface_hub.utils import LocalEntryNotFoundError
from vllm.model_executor.model_loader.weight_utils import (
2026-01-19 10:38:50 +08:00
download_weights_from_hf,
enable_hf_transfer,
)
2026-01-09 13:34:11 +08:00
def test_hf_transfer_auto_activation():
    """Check that enable_hf_transfer() switches hf_transfer on iff it is importable.

    After calling enable_hf_transfer(), the huggingface_hub constant
    HF_HUB_ENABLE_HF_TRANSFER must be True exactly when the optional
    ``hf_transfer`` package can be imported.

    The test is skipped when HF_HUB_ENABLE_HF_TRANSFER is already present in
    the environment, because auto-activation cannot be observed in that case.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        pytest.skip("HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation")

    enable_hf_transfer()

    # Probe availability the same way: hf_transfer counts as available
    # exactly when importing it succeeds.
    try:
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        hf_transfer_available = False
    else:
        hf_transfer_available = True

    assert huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == hf_transfer_available
2026-01-09 13:34:11 +08:00
def test_download_weights_from_hf():
    """Check offline/online behaviour of download_weights_from_hf.

    Steps, all against a fresh temporary cache directory:

    1. With HF_HUB_OFFLINE set and nothing cached, the call raises
       LocalEntryNotFoundError.
    2. With HF_HUB_OFFLINE cleared, the weights are downloaded into the cache.
    3. With HF_HUB_OFFLINE set again, the now-cached weights are found and
       the call returns a non-None result.

    huggingface_hub.constants.HF_HUB_OFFLINE is process-global state. Its
    original value is restored on exit, including on failure, so that
    later tests in the same session are not forced into offline mode.
    """

    def _download(cache_dir):
        # Build a fresh allow_patterns list on every call, as the original
        # test did, so no call can observe another call's list.
        return download_weights_from_hf(
            "facebook/opt-125m",
            allow_patterns=["*.safetensors", "*.bin"],
            cache_dir=cache_dir,
        )

    original_offline = huggingface_hub.constants.HF_HUB_OFFLINE
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # assert LocalEntryNotFoundError is thrown
            # if offline is set and model is not cached
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            with pytest.raises(LocalEntryNotFoundError):
                _download(tmpdir)

            # download the model
            huggingface_hub.constants.HF_HUB_OFFLINE = False
            _download(tmpdir)

            # now it should work offline
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            assert _download(tmpdir) is not None
    finally:
        # Undo the global mutation. Without this, HF_HUB_OFFLINE stays True
        # for every test that runs after this one.
        huggingface_hub.constants.HF_HUB_OFFLINE = original_offline
2026-01-09 13:34:11 +08:00
if __name__ == "__main__":
    # Allow running these checks directly, without the pytest runner,
    # in the same order as they are defined above.
    for _check in (test_hf_transfer_auto_activation, test_download_weights_from_hf):
        _check()