Sync from v0.13
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor import UniProcExecutor
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
|
||||
# This is a dummy executor for patching in test_runai_model_streamer_s3.py.
|
||||
# We cannot use vllm_runner fixture here, because it spawns worker process.
|
||||
# The worker process reimports the patched entities, and the patch is not applied.
|
||||
class RunaiDummyExecutor(UniProcExecutor):
    """Single-process executor for the Run:ai streamer patch tests.

    The ``vllm_runner`` fixture cannot be used for those tests because it
    spawns a worker process; the spawned worker re-imports the patched
    entities, so the monkeypatch would not be applied. Running the worker
    in-process keeps the patches visible.
    """

    def _init_executor(self) -> None:
        """Create and initialize a single in-process driver worker."""
        init_method = get_distributed_init_method(get_ip(), get_open_port())

        # Default to device index 0 unless the configured device string
        # carries an explicit index (e.g. "cuda:1").
        device_parts = str(self.vllm_config.device_config.device).split(":")
        local_rank = int(device_parts[1]) if len(device_parts) > 1 else 0

        worker_rpc_kwargs = {
            "vllm_config": self.vllm_config,
            "local_rank": local_rank,
            "rank": 0,
            "distributed_init_method": init_method,
            "is_driver_worker": True,
        }

        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config)

        self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
        self.collective_rpc("init_device")
|
||||
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
|
||||
# Load weights through the Run:ai model streamer loader.
load_format = "runai_streamer"
# Small public HF model used for the smoke tests.
test_model = "openai-community/gpt2"
# TODO(amacaskill): Replace with a GKE owned GCS bucket.
test_gcs_model = "gs://vertex-model-garden-public-us/codegemma/codegemma-2b/"

# Shared prompts for the generation smoke tests below.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.  Fixed seed keeps generation deterministic.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
||||
|
||||
|
||||
def get_runai_model_loader():
    """Return the model loader selected by the runai_streamer load format."""
    return get_model_loader(LoadConfig(load_format=load_format))
|
||||
|
||||
|
||||
def test_get_model_loader_with_runai_flag():
    """The runai_streamer load format must select RunaiModelStreamerLoader."""
    loader = get_runai_model_loader()
    assert type(loader).__name__ == "RunaiModelStreamerLoader"
|
||||
|
||||
|
||||
def test_runai_model_loader_download_files(vllm_runner):
    """Smoke test: generation works with weights loaded via the streamer."""
    with vllm_runner(test_model, load_format=load_format) as llm:
        outputs = llm.generate(prompts, sampling_params)
        assert outputs
|
||||
|
||||
|
||||
def test_runai_model_loader_download_files_gcs(
    vllm_runner, monkeypatch: pytest.MonkeyPatch
):
    """Smoke test: generate from a public GCS model, anonymously."""
    # Avoid default-project lookup and authenticated access; point the
    # streamer at the real GCS endpoint with anonymous credentials.
    env = {
        "GOOGLE_CLOUD_PROJECT": "fake-project",
        "RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS": "true",
        "CLOUD_STORAGE_EMULATOR_ENDPOINT": "https://storage.googleapis.com",
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)
    with vllm_runner(test_gcs_model, load_format=load_format) as llm:
        outputs = llm.generate(prompts, sampling_params)
        assert outputs
|
||||
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from runai_model_streamer.safetensors_streamer.streamer_mock import StreamerPatcher
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
from .conftest import RunaiDummyExecutor
|
||||
|
||||
# Load weights through the Run:ai model streamer loader.
load_format = "runai_streamer"
# Small public HF model served locally through the mocked S3 streamer.
test_model = "openai-community/gpt2"
|
||||
|
||||
|
||||
def test_runai_model_loader_download_files_s3_mocked_with_patch(
    vllm_runner,
    tmp_path: Path,
    monkeypatch,
):
    """Load an ``s3://`` model end-to-end with the streamer fully mocked.

    The streamer's S3 accessors are shimmed to serve files from a local
    HF snapshot, so no real object storage is contacted.  The dummy
    executor runs the worker in-process so the patches stay applied.
    """
    patcher = StreamerPatcher(str(tmp_path))

    mock_s3_uri = "s3://my-mock-bucket/gpt2/"

    # Materialize the model locally; the shims serve it as if from S3.
    local_model_dir = f"{tmp_path}/gpt2"
    snapshot_download(repo_id=test_model, local_dir=local_model_dir)

    shims = [
        (
            "vllm.transformers_utils.runai_utils.runai_list_safetensors",
            patcher.shim_list_safetensors,
        ),
        (
            "vllm.transformers_utils.runai_utils.runai_pull_files",
            patcher.shim_pull_files,
        ),
        (
            "vllm.model_executor.model_loader.weight_utils.SafetensorsStreamer",
            patcher.create_mock_streamer,
        ),
    ]
    for target, replacement in shims:
        monkeypatch.setattr(target, replacement)

    engine_args = EngineArgs(
        model=mock_s3_uri,
        load_format=load_format,
        tensor_parallel_size=1,
    )
    vllm_config = engine_args.create_engine_config()

    # In-process executor keeps the monkeypatches visible to the worker.
    executor = RunaiDummyExecutor(vllm_config)
    executor.driver_worker.load_model()
|
||||
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import huggingface_hub.constants
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf
|
||||
from vllm.transformers_utils.runai_utils import (
|
||||
ObjectStorageModel,
|
||||
is_runai_obj_uri,
|
||||
list_safetensors,
|
||||
)
|
||||
|
||||
|
||||
def test_is_runai_obj_uri():
    """Only gs:// and s3:// URIs count as object-storage URIs."""
    cases = {
        "gs://some-gcs-bucket/path": True,
        "s3://some-s3-bucket/path": True,
        "nfs://some-nfs-path": False,
    }
    for uri, expected in cases.items():
        assert bool(is_runai_obj_uri(uri)) is expected
|
||||
|
||||
|
||||
def test_runai_list_safetensors_local():
    """list_safetensors on a local dir finds every .safetensors file."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # The snapshot must actually be downloadable for this test.
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        download_weights_from_hf(
            "openai-community/gpt2",
            allow_patterns=["*.safetensors", "*.json"],
            cache_dir=tmpdir,
        )
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert safetensors
        # All safetensors files share a snapshot directory; list that one.
        snapshot_dir = os.path.dirname(safetensors[0])
        listed = list_safetensors(snapshot_dir)
        assert len(listed) == len(safetensors)
|
||||
|
||||
|
||||
def test_runai_pull_files_gcs(monkeypatch):
    """Pull a public file from GCS anonymously and verify its MD5 checksum.

    Fix: the URL and allow-pattern f-strings were corrupted to a literal
    "(unknown)" placeholder; they must interpolate ``filename`` — the
    checksum check below opens ``model.dir/filename``, so that exact file
    must be what gets pulled.
    """
    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
    # Bypass default project lookup by setting GOOGLE_CLOUD_PROJECT
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
    filename = "LT08_L1GT_074061_20130309_20170505_01_T2_MTL.txt"
    gcs_bucket = "gs://gcp-public-data-landsat/LT08/01/074/061/LT08_L1GT_074061_20130309_20170505_01_T2/"
    gcs_url = f"{gcs_bucket}/{filename}"
    model = ObjectStorageModel(gcs_url)
    model.pull_files(gcs_bucket, allow_pattern=[f"*{filename}"])
    # To re-generate / change URLs:
    # gsutil ls -L gs://<gcs-url> | grep "Hash (md5)" | tr -d ' ' \
    # | cut -d":" -f2 | base64 -d | xxd -p
    expected_checksum = "f60dea775da1392434275b311b31a431"
    hasher = hashlib.new("md5")
    with open(os.path.join(model.dir, filename), "rb") as f:
        # Read the file in chunks to handle large files efficiently
        for chunk in iter(lambda: f.read(4096), b""):
            hasher.update(chunk)
    actual_checksum = hasher.hexdigest()
    assert actual_checksum == expected_checksum
|
||||
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import glob
|
||||
import tempfile
|
||||
|
||||
import huggingface_hub.constants
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
def test_runai_model_loader():
    """The streamer iterator must yield the same tensors as safetensors."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # The snapshot must actually be downloadable for this test.
        huggingface_hub.constants.HF_HUB_OFFLINE = False
        download_weights_from_hf(
            "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir
        )
        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
        assert len(safetensors) > 0

        streamer_tensors = dict(runai_safetensors_weights_iterator(safetensors, True))
        reference_tensors = dict(safetensors_weights_iterator(safetensors, True))

        assert len(streamer_tensors) == len(reference_tensors)

        # Every streamed tensor must match its reference in dtype,
        # shape and values.
        for name, streamed in streamer_tensors.items():
            reference = reference_tensors[name]
            assert streamed.dtype == reference.dtype
            assert streamed.shape == reference.shape
            assert torch.all(streamed.eq(reference))


if __name__ == "__main__":
    test_runai_model_loader()
|
||||
Reference in New Issue
Block a user