Sync from v0.13
This commit is contained in:
66
tests/model_executor/test_weight_utils.py
Normal file
66
tests/model_executor/test_weight_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import huggingface_hub.constants
|
||||
import pytest
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf,
|
||||
enable_hf_transfer,
|
||||
)
|
||||
|
||||
|
||||
def test_hf_transfer_auto_activation():
|
||||
if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
|
||||
# in case it is already set, we can't test the auto activation
|
||||
pytest.skip("HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation")
|
||||
enable_hf_transfer()
|
||||
try:
|
||||
# enable hf hub transfer if available
|
||||
import hf_transfer # type: ignore # noqa
|
||||
|
||||
HF_TRANSFER_ACTIVE = True
|
||||
except ImportError:
|
||||
HF_TRANSFER_ACTIVE = False
|
||||
assert huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == HF_TRANSFER_ACTIVE
|
||||
|
||||
|
||||
def test_download_weights_from_hf():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# assert LocalEntryNotFoundError error is thrown
|
||||
# if offline is set and model is not cached
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = True
|
||||
with pytest.raises(LocalEntryNotFoundError):
|
||||
download_weights_from_hf(
|
||||
"facebook/opt-125m",
|
||||
allow_patterns=["*.safetensors", "*.bin"],
|
||||
cache_dir=tmpdir,
|
||||
)
|
||||
|
||||
# download the model
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = False
|
||||
download_weights_from_hf(
|
||||
"facebook/opt-125m",
|
||||
allow_patterns=["*.safetensors", "*.bin"],
|
||||
cache_dir=tmpdir,
|
||||
)
|
||||
|
||||
# now it should work offline
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = True
|
||||
assert (
|
||||
download_weights_from_hf(
|
||||
"facebook/opt-125m",
|
||||
allow_patterns=["*.safetensors", "*.bin"],
|
||||
cache_dir=tmpdir,
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_transfer_auto_activation()
|
||||
test_download_weights_from_hf()
|
||||
Reference in New Issue
Block a user