Sync from v0.13
This commit is contained in:
95
vllm/transformers_utils/s3_utils.py
Normal file
95
vllm/transformers_utils/s3_utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import fnmatch
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.client import BaseClient
|
||||
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
boto3 = PlaceholderModule("boto3") # type: ignore[assignment]
|
||||
|
||||
|
||||
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
return [
|
||||
path
|
||||
for path in paths
|
||||
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||
]
|
||||
|
||||
|
||||
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
return [
|
||||
path
|
||||
for path in paths
|
||||
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||
]
|
||||
|
||||
|
||||
def glob(
|
||||
s3: Optional["BaseClient"] = None,
|
||||
path: str = "",
|
||||
allow_pattern: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
List full file names from S3 path and filter by allow pattern.
|
||||
|
||||
Args:
|
||||
s3: S3 client to use.
|
||||
path: The S3 path to list from.
|
||||
allow_pattern: A list of patterns of which files to pull.
|
||||
|
||||
Returns:
|
||||
list[str]: List of full S3 paths allowed by the pattern
|
||||
"""
|
||||
if s3 is None:
|
||||
s3 = boto3.client("s3")
|
||||
if not path.endswith("/"):
|
||||
path = path + "/"
|
||||
bucket_name, _, paths = list_files(s3, path=path, allow_pattern=allow_pattern)
|
||||
return [f"s3://{bucket_name}/{path}" for path in paths]
|
||||
|
||||
|
||||
def list_files(
|
||||
s3: "BaseClient",
|
||||
path: str,
|
||||
allow_pattern: list[str] | None = None,
|
||||
ignore_pattern: list[str] | None = None,
|
||||
) -> tuple[str, str, list[str]]:
|
||||
"""
|
||||
List files from S3 path and filter by pattern.
|
||||
|
||||
Args:
|
||||
s3: S3 client to use.
|
||||
path: The S3 path to list from.
|
||||
allow_pattern: A list of patterns of which files to pull.
|
||||
ignore_pattern: A list of patterns of which files not to pull.
|
||||
|
||||
Returns:
|
||||
tuple[str, str, list[str]]: A tuple where:
|
||||
- The first element is the bucket name
|
||||
- The second element is string represent the bucket
|
||||
and the prefix as a dir like string
|
||||
- The third element is a list of files allowed or
|
||||
disallowed by pattern
|
||||
"""
|
||||
parts = path.removeprefix("s3://").split("/")
|
||||
prefix = "/".join(parts[1:])
|
||||
bucket_name = parts[0]
|
||||
|
||||
objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
|
||||
paths = [obj["Key"] for obj in objects.get("Contents", [])]
|
||||
|
||||
paths = _filter_ignore(paths, ["*/"])
|
||||
if allow_pattern is not None:
|
||||
paths = _filter_allow(paths, allow_pattern)
|
||||
|
||||
if ignore_pattern is not None:
|
||||
paths = _filter_ignore(paths, ignore_pattern)
|
||||
|
||||
return bucket_name, prefix, paths
|
||||
Reference in New Issue
Block a user