51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from .base import get_vllm_public_assets
|
|
|
|
# Prefix (passed as ``s3_prefix`` to ``get_vllm_public_assets``) under which
# the vision-model test images are stored.
VLM_IMAGES_DIR = "vision_model_images"

# Closed set of known image asset names. Each name is combined with a file
# extension as ``<name>.<ext>`` when the asset file is resolved.
ImageAssetName = Literal[
    "stop_sign",
    "cherry_blossom",
    "hato",
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
    "Grayscale_8bits_palette_sample_image",
    "1280px-Venn_diagram_rgb",
    "RGBA_comp",
    "237-400x300",
    "231-200x300",
    "27-500x500",
    "17-150x600",
    "handelsblatt-preview",
    "paper-11",
]
|
|
|
|
|
|
@dataclass(frozen=True)
class ImageAsset:
    """A named test image from the vLLM public assets.

    Every file belonging to this asset is addressed as ``<name>.<ext>``
    under the ``VLM_IMAGES_DIR`` prefix and resolved through
    ``get_vllm_public_assets``.
    """

    # Asset identifier; one of the ImageAssetName literals.
    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """Return the path of this asset's ``<name>.<ext>`` file.

        The path is obtained from ``get_vllm_public_assets`` using the
        ``VLM_IMAGES_DIR`` prefix. The result is opened directly as a local
        file by the loaders in this class (``Image.open``, ``torch.load``,
        ``Path.read_bytes``).

        Args:
            ext: File extension without the leading dot (e.g. ``"jpg"``).

        Returns:
            Path to the asset file.
        """
        return get_vllm_public_assets(filename=f"{self.name}.{ext}",
                                      s3_prefix=VLM_IMAGES_DIR)

    @property
    def pil_image(self) -> Image.Image:
        """The asset's ``.jpg`` file opened with PIL.

        A property getter cannot receive arguments, so the extension is
        fixed here; use ``pil_image_ext`` to load a different extension.
        """
        return self.pil_image_ext(ext="jpg")

    def pil_image_ext(self, ext: str) -> Image.Image:
        """Open this asset's ``<name>.<ext>`` file with PIL.

        Args:
            ext: File extension without the leading dot (e.g. ``"png"``).

        Returns:
            The image returned by ``PIL.Image.open``. Opening is lazy: pixel
            data is read from the file when it is first accessed.
        """
        image_path = self.get_path(ext)
        return Image.open(image_path)

    @property
    def image_embeds(self) -> torch.Tensor:
        """Precomputed image embeddings loaded from ``<name>.pt``.

        Only used for testing purposes with llava 1.5. The tensor is mapped
        to CPU, and ``weights_only=True`` restricts unpickling to tensors and
        primitive types rather than arbitrary objects.
        """
        image_path = self.get_path("pt")
        return torch.load(image_path, map_location="cpu", weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        """Return the raw bytes of this asset's ``<name>.<ext>`` file.

        Args:
            ext: File extension without the leading dot.

        Returns:
            The complete file contents.
        """
        # Wrap in Path() so this also works if the resolver hands back a str.
        p = Path(self.get_path(ext))
        return p.read_bytes()
|