Sync from v0.13

commit 5aef6c175a
parent b2ef04d792
2026-01-19 10:38:50 +08:00
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import shutil
import pytest
from huggingface_hub import snapshot_download
from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver
MODEL_NAME = "Qwen/Qwen3-0.6B"
LORA_NAME = "charent/self_cognition_Alice"
PA_NAME = "swapnilbp/llama_tweet_ptune"
@pytest.fixture(scope="module")
def adapter_cache(request, tmpdir_factory):
# Create dir that mimics the structure of the adapter cache
adapter_cache = tmpdir_factory.mktemp(request.module.__name__) / "adapter_cache"
return adapter_cache
@pytest.fixture(scope="module")
def qwen3_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def pa_files():
return snapshot_download(repo_id=PA_NAME)
@pytest.mark.asyncio
async def test_filesystem_resolver(adapter_cache, qwen3_lora_files):
model_files = adapter_cache / LORA_NAME
shutil.copytree(qwen3_lora_files, model_files)
fs_resolver = FilesystemResolver(adapter_cache)
assert fs_resolver is not None
lora_request = await fs_resolver.resolve_lora(MODEL_NAME, LORA_NAME)
assert lora_request is not None
assert lora_request.lora_name == LORA_NAME
assert lora_request.lora_path == os.path.join(adapter_cache, LORA_NAME)
@pytest.mark.asyncio
async def test_missing_adapter(adapter_cache):
fs_resolver = FilesystemResolver(adapter_cache)
assert fs_resolver is not None
missing_lora_request = await fs_resolver.resolve_lora(MODEL_NAME, "foobar")
assert missing_lora_request is None
@pytest.mark.asyncio
async def test_nonlora_adapter(adapter_cache, pa_files):
model_files = adapter_cache / PA_NAME
shutil.copytree(pa_files, model_files)
fs_resolver = FilesystemResolver(adapter_cache)
assert fs_resolver is not None
pa_request = await fs_resolver.resolve_lora(MODEL_NAME, PA_NAME)
assert pa_request is None
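
Taken together, these tests pin down the resolver's contract: it scans a cache directory whose subdirectories are named after adapters, returns a LoRARequest for a valid LoRA directory, and returns None for a missing adapter or a non-LoRA (e.g., prompt-tuning) directory. A minimal usage sketch along those lines, with a hypothetical cache root:

import asyncio

from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver

async def main():
    # /srv/adapter_cache/<adapter_name>/ holds the adapter files (hypothetical path)
    resolver = FilesystemResolver("/srv/adapter_cache")
    request = await resolver.resolve_lora("Qwen/Qwen3-0.6B", "my-adapter")
    print(request)  # LoRARequest if found and valid, otherwise None

asyncio.run(main())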

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def register_prithvi():
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501

View File

@@ -0,0 +1,411 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import datetime
import os
import tempfile
import urllib.request
from collections.abc import Sequence
from typing import Any
import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
)
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
logger = init_logger(__name__)
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99
DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
datamodule_config: DataModuleConfig = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size": 16,
"constant_scale": 0.0001,
"data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last": True,
"no_data_replace": 0.0,
"no_label_replace": -1,
"num_workers": 8,
"test_transform": [
albumentations.Resize(
always_apply=False, height=448, interpolation=1, p=1, width=448
),
albumentations.pytorch.ToTensorV2(
transpose_mask=False, always_apply=True, p=1.0
),
],
}
def save_geotiff(image: torch.Tensor, meta: dict, out_format: str) -> str | bytes:
"""Save multi-band image in Geotiff file.
Args:
image: np.ndarray with shape (bands, height, width)
output_path: path where to save the image
meta: dict with meta info.
"""
if out_format == "path":
# create temp file
file_path = os.path.join(os.getcwd(), "prediction.tiff")
with rasterio.open(file_path, "w", **meta) as dest:
for i in range(image.shape[0]):
dest.write(image[i, :, :], i + 1)
return file_path
elif out_format == "b64_json":
with tempfile.NamedTemporaryFile() as tmpfile:
with rasterio.open(tmpfile.name, "w", **meta) as dest:
for i in range(image.shape[0]):
dest.write(image[i, :, :], i + 1)
file_data = tmpfile.read()
return base64.b64encode(file_data)
else:
raise ValueError("Unknown output format")
def _convert_np_uint8(float_image: torch.Tensor):
image = float_image.numpy() * 255.0
image = image.astype(dtype=np.uint8)
return image
def read_geotiff(
file_path: str | None = None,
path_type: str | None = None,
file_data: bytes | None = None,
) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
"""Read all bands from *file_path* and return image + meta info.
Args:
file_path: path to image file.
Returns:
np.ndarray with shape (bands, height, width)
meta info dict
"""
if all([x is None for x in [file_path, path_type, file_data]]):
raise Exception("All input fields to read_geotiff are None")
write_to_file: bytes | None = None
path: str | None = None
if file_data is not None:
write_to_file = file_data
elif file_path is not None and path_type == "url":
resp = urllib.request.urlopen(file_path)
write_to_file = resp.read()
elif file_path is not None and path_type == "path":
path = file_path
elif file_path is not None and path_type == "b64_json":
image_data = base64.b64decode(file_path)
write_to_file = image_data
else:
raise Exception("Wrong combination of parameters to read_geotiff")
with tempfile.NamedTemporaryFile() as tmpfile:
path_to_use = None
        if write_to_file:
            tmpfile.write(write_to_file)
            tmpfile.flush()  # ensure the bytes are on disk before rasterio opens the path
            path_to_use = tmpfile.name
elif path:
path_to_use = path
with rasterio.open(path_to_use) as src:
img = src.read()
meta = src.meta
try:
coords = src.lnglat()
except Exception:
# Cannot read coords
coords = None
return img, meta, coords
def load_image(
data: list[str],
path_type: str,
mean: list[float] | None = None,
std: list[float] | None = None,
    indices: list[int] | None = None,
):
"""Build an input example by loading images in *file_paths*.
Args:
file_paths: list of file paths .
mean: list containing mean values for each band in the
images in *file_paths*.
std: list containing std values for each band in the
images in *file_paths*.
Returns:
np.array containing created example
list of meta info for each image in *file_paths*
"""
imgs = []
metas = []
temporal_coords = []
location_coords = []
for file in data:
img, meta, coords = read_geotiff(file_path=file, path_type=path_type)
# Rescaling (don't normalize on nodata)
img = np.moveaxis(img, 0, -1) # channels last for rescaling
if indices is not None:
img = img[..., indices]
if mean is not None and std is not None:
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
imgs.append(img)
metas.append(meta)
if coords is not None:
location_coords.append(coords)
try:
match = re.search(r"(\d{7,8}T\d{6})", file)
if match:
year = int(match.group(1)[:4])
julian_day = match.group(1).split("T")[0][4:]
if len(julian_day) == 3:
julian_day = int(julian_day)
else:
julian_day = (
datetime.datetime.strptime(julian_day, "%m%d")
.timetuple()
.tm_yday
)
temporal_coords.append([year, julian_day])
except Exception:
logger.exception("Could not extract timestamp for %s", file)
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim
return imgs, temporal_coords, location_coords, metas
class PrithviMultimodalDataProcessor(IOProcessor):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.datamodule = Sen1Floods11NonGeoDataModule(
data_root=datamodule_config["data_root"],
batch_size=datamodule_config["batch_size"],
num_workers=datamodule_config["num_workers"],
bands=datamodule_config["bands"],
drop_last=datamodule_config["drop_last"],
test_transform=datamodule_config["test_transform"],
)
self.img_size = 512
self.h1 = 1
self.w1 = 1
self.original_h = 512
self.original_w = 512
self.batch_size = 1
self.meta_data = None
self.requests_cache: dict[str, dict[str, Any]] = {}
self.indices = DEFAULT_INPUT_INDICES
def parse_request(self, request: Any) -> IOProcessorInput:
if type(request) is dict:
image_prompt = ImagePrompt(**request)
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
request_data = request.data
if type(request_data) is dict:
return ImagePrompt(**request_data)
else:
raise ValueError("Unable to parse the request data")
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
)
def pre_process(
self,
prompt: IOProcessorInput,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
image_data = dict(prompt)
if request_id:
self.requests_cache[request_id] = {
"out_format": image_data["out_data_format"],
}
input_data, temporal_coords, location_coords, meta_data = load_image(
data=[image_data["data"]],
indices=self.indices,
path_type=image_data["data_format"],
)
self.meta_data = meta_data[0]
if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1
self.original_h, self.original_w = input_data.shape[-2:]
pad_h = (self.img_size - (self.original_h % self.img_size)) % self.img_size
pad_w = (self.img_size - (self.original_w % self.img_size)) % self.img_size
input_data = np.pad(
input_data,
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
mode="reflect",
)
batch = torch.tensor(input_data)
windows = batch.unfold(3, self.img_size, self.img_size).unfold(
4, self.img_size, self.img_size
)
self.h1, self.w1 = windows.shape[3:5]
windows = rearrange(
windows,
"b c t h1 w1 h w -> (b h1 w1) c t h w",
h=self.img_size,
w=self.img_size,
)
# Split into batches if number of windows > batch_size
num_batches = (
windows.shape[0] // self.batch_size
if windows.shape[0] > self.batch_size
else 1
)
windows = torch.tensor_split(windows, num_batches, dim=0)
if temporal_coords:
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
else:
location_coords = None
prompts = []
for window in windows:
# Apply standardization
window = self.datamodule.test_transform(
image=window.squeeze().numpy().transpose(1, 2, 0)
)
window = self.datamodule.aug(window)["image"]
prompts.append(
{
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
},
}
)
return prompts
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
pred_imgs_list = []
if request_id and (request_id in self.requests_cache):
out_format = self.requests_cache[request_id]["out_format"]
else:
out_format = "b64_json"
for output in model_output:
y_hat = output.outputs.data.argmax(dim=0)
pred = torch.nn.functional.interpolate(
y_hat[None, None, ...].float(),
size=self.img_size,
mode="nearest",
)
pred_imgs_list.append(pred)
pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0)
# Build images from patches
pred_imgs = rearrange(
pred_imgs,
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
h=self.img_size,
w=self.img_size,
b=1,
c=1,
h1=self.h1,
w1=self.w1,
)
# Cut padded area back to original size
pred_imgs = pred_imgs[..., : self.original_h, : self.original_w]
# Squeeze (batch size 1)
pred_imgs = pred_imgs[0]
if not self.meta_data:
raise ValueError("No metadata available for the current task")
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
out_data = save_geotiff(
_convert_np_uint8(pred_imgs), self.meta_data, out_format
)
return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id
)
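
The tiling in pre_process() is the load-bearing step: the padded input is cut into non-overlapping img_size x img_size windows with unfold, flattened into a pseudo-batch with einops, and post_process() later inverts the rearrangement to stitch per-window predictions back into one image. A standalone sketch of the forward half on a toy tensor (shapes are illustrative only):

import torch
from einops import rearrange

img_size = 512
batch = torch.zeros(1, 6, 1, 1024, 512)  # (b, c, t, H, W), already padded
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]  # tiles per column/row: here (2, 1)
windows = rearrange(
    windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
)
assert windows.shape == (2, 6, 1, 512, 512)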

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal, TypedDict
import albumentations
from pydantic import BaseModel
class DataModuleConfig(TypedDict):
bands: list[str]
batch_size: int
constant_scale: float
data_root: str
drop_last: bool
no_data_replace: float
no_label_replace: int
num_workers: int
test_transform: list[albumentations.core.transforms_interface.BasicTransform]
class ImagePrompt(BaseModel):
data_format: Literal["b64_json", "bytes", "url", "path"]
"""
This is the data type for the input image
"""
image_format: str
"""
This is the image format (e.g., jpeg, png, etc.)
"""
    out_data_format: Literal["b64_json", "path"]
data: Any
"""
Input image data
"""
MultiModalPromptType = ImagePrompt
class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
Args:
type (str): The data content type [path, object]
format (str): The image format (e.g., jpeg, png, etc.)
data (Any): The resulting data.
"""
type: Literal["path", "b64_json"]
format: str
data: str
request_id: str | None = None
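
A plain dict passed to parse_request() maps directly onto ImagePrompt. A hedged example payload (the URL and field values are illustrative; the module path is inferred from the relative import in the processor file above):

from prithvi_io_processor.types import ImagePrompt

prompt = ImagePrompt(
    data_format="url",
    image_format="tiff",
    out_data_format="b64_json",
    data="https://example.com/scene.tif",  # hypothetical input image
)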

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from setuptools import setup
setup(
name="prithvi_io_processor_plugin",
version="0.1",
packages=["prithvi_io_processor"],
entry_points={
"vllm.io_processor_plugins": [
"prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501
]
},
)
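
Once installed, the plugin is discoverable through the standard entry-point machinery. A minimal check, using the Python 3.10+ importlib.metadata API:

from importlib.metadata import entry_points

for ep in entry_points(group="vllm.io_processor_plugins"):
    print(ep.name, "->", ep.value)
# expected: prithvi_to_tiff -> prithvi_io_processor:register_prithvi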

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from setuptools import setup
setup(
name="vllm_add_dummy_model",
version="0.1",
packages=["vllm_add_dummy_model"],
entry_points={
"vllm.general_plugins": ["register_dummy_model = vllm_add_dummy_model:register"]
},
)

View File

@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import ModelRegistry
def register():
# Test directly passing the model
from .my_opt import MyOPTForCausalLM
if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
# Test passing lazy model
if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model(
"MyGemma2Embedding",
"vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
)
if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava")

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.sequence import IntermediateTensors
class MyGemma2Embedding(nn.Module):
is_pooling_model = True
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids,
positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(hidden_states, IntermediateTensors):
return hidden_states
# Return all-zero embeddings
return torch.zeros_like(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = (
(name, data) for name, data in weights if not name.startswith("lm_head.")
)
return self.model.load_weights(weights)
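
load_weights() relies on the class-level hf_to_vllm_mapper: checkpoint names prefixed with "model." are rewritten before loading, and lm_head weights are filtered out since this embedding model never produces logits. A small illustration of the prefix mapping, assuming apply() maps names over (name, tensor) pairs as the code above implies:

from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = [("model.embed_tokens.weight", None), ("lm_head.weight", None)]
print([name for name, _ in mapper.apply(weights)])
# expected: ['embed_tokens.weight', 'lm_head.weight']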

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.models.llava import (
LlavaDummyInputsBuilder,
LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
LlavaProcessingInfo,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_processor(
LlavaMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder,
)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits
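
With every logit zeroed and index 0 bumped, argmax is always token id 0, so greedy decoding from this model (and from MyOPTForCausalLM below, which uses the same override) degenerates to repeating the first vocabulary entry; tests can rely on that fixed output. The effect in isolation:

import torch

logits = torch.randn(2, 32000)  # (num_seqs, vocab_size), toy values
logits.zero_()
logits[:, 0] += 1.0
assert torch.all(logits.argmax(dim=-1) == 0)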

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.models.opt import OPTForCausalLM
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states)
if logits is not None:
logits.zero_()
logits[:, 0] += 1.0
return logits

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from setuptools import setup
setup(
name="vllm_add_dummy_platform",
version="0.1",
packages=["vllm_add_dummy_platform"],
entry_points={
"vllm.platform_plugins": [
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
],
"vllm.general_plugins": [
"dummy_custom_ops = vllm_add_dummy_platform:register_ops"
],
},
)

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def dummy_platform_plugin() -> str | None:
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
def register_ops():
import vllm_add_dummy_platform.dummy_custom_ops # noqa

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend
class DummyAttentionBackend(PlaceholderAttentionBackend):
@staticmethod
def get_name() -> str:
return "Dummy_Backend"

View File

@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
# Register CustomRotaryEmbedding to CustomOP.
@RotaryEmbedding.register_oot
class DummyRotaryEmbedding(RotaryEmbedding):
"""Original rotary positional embedding."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.addition_config = True
def forward_oot(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
return super().forward_oot(*args, **kwargs)

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm.platforms.interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
class DummyPlatform(Platform):
_enum = PlatformEnum.OOT
device_name = "DummyDevice"
device_type: str = "privateuseone"
dispatch_key: str = "PrivateUse1"
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
vllm_config.compilation_config.custom_ops = ["all"]
    @classmethod
    def get_attn_backend_cls(
        cls,
backend_name,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.metrics.loggers import StatLoggerBase
class DummyStatLogger(StatLoggerBase):
"""
A dummy stat logger for testing purposes.
Implements the minimal interface expected by StatLoggerManager.
"""
def __init__(self, vllm_config, engine_idx=0):
self.vllm_config = vllm_config
self.engine_idx = engine_idx
self.recorded = []
self.logged = False
self.engine_initialized = False
def record(self, scheduler_stats, iteration_stats, mm_cache_stats, engine_idx):
self.recorded.append(
(scheduler_stats, iteration_stats, mm_cache_stats, engine_idx)
)
def log(self):
self.logged = True
def log_engine_initialized(self):
self.engine_initialized = True
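
The recorded, logged, and engine_initialized attributes exist purely so tests can assert that the manager invoked each hook. A direct exercise of the interface with placeholder arguments (None stands in for the real stats objects):

logger = DummyStatLogger(vllm_config=None)
logger.record(None, None, None, engine_idx=0)
logger.log()
logger.log_engine_initialized()
assert logger.logged and logger.engine_initialized and len(logger.recorded) == 1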

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from setuptools import setup
setup(
name="dummy_stat_logger",
version="0.1",
packages=["dummy_stat_logger"],
entry_points={
"vllm.stat_logger_plugins": [
"dummy_stat_logger = dummy_stat_logger.dummy_stat_logger:DummyStatLogger" # noqa
]
},
)