Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -7,6 +7,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
from abc import ABC
from collections.abc import Iterable
@@ -18,6 +19,8 @@ from transformers import AutoModel, PretrainedConfig
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.internvl import (
@@ -30,24 +33,29 @@ from vllm.model_executor.models.internvl import (
InternVLProcessor,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_image_processor_from_config
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
from .interfaces import (
MultiModalEmbeddings,
SupportsCrossEncoding,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
from .interfaces_base import VllmModelForPooling
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
def build_transform(input_size: int):
@@ -183,10 +191,12 @@ def image_to_pixel_values_nemotron_vl(
min_num: int,
max_num: int,
use_thumbnail: bool,
transform: T.Compose | None = None,
) -> torch.Tensor:
target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)
transform = build_transform(input_size=input_size)
if transform is None:
transform = build_transform(input_size=input_size)
images = dynamic_preprocess_nemotron_vl(
image,
@@ -200,11 +210,15 @@ def image_to_pixel_values_nemotron_vl(
class NemotronVLProcessor(InternVLProcessor):
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
def __init__(
self,
config: PretrainedConfig,
tokenizer: TokenizerLike,
image_processor: BaseImageProcessorFast,
image_processor: BaseImageProcessorFast | None = None,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
@@ -236,11 +250,18 @@ class NemotronVLProcessor(InternVLProcessor):
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail: bool = self.image_processor.use_thumbnail
if image_processor is not None:
self.use_thumbnail = image_processor.use_thumbnail
else:
self.use_thumbnail = getattr(config, "use_thumbnail", True)
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
return self.tokenizer.get_vocab()[self.IMG_CONTEXT]
def _get_transform(self) -> T.Compose:
return build_transform(input_size=self.image_size)
def get_num_image_tokens(
self,
@@ -283,10 +304,26 @@ class NemotronVLProcessor(InternVLProcessor):
min_num=min_num,
max_num=max_num,
use_thumbnail=self.use_thumbnail,
transform=self._get_transform(),
)
for image in images
]
def _replace_image_tokens(
self,
text: list[str],
pixel_values_lst: list[torch.Tensor],
) -> list[str]:
"""Replace <image> placeholders with image tokens."""
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
# Use temporary placeholder to avoid replacing tokens we just inserted
NVL_IMAGE_CONTEXT = image_repl.full.replace("<image>", "<NVL_IMG_CONTEXT>")
text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text]
return [t.replace("<NVL_IMG_CONTEXT>", self.IMG_CONTEXT) for t in text]
def _preprocess_image(
self,
text: list[str],
@@ -311,15 +348,7 @@ class NemotronVLProcessor(InternVLProcessor):
),
}
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
NVL_IMAGE_CONTEXT = image_repl.full.replace(
"<image>", "<NVL_IMG_CONTEXT>"
)
text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text]
text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text]
text = self._replace_image_tokens(text, pixel_values_lst)
return text, image_inputs
def get_image_repl(
@@ -327,10 +356,10 @@ class NemotronVLProcessor(InternVLProcessor):
feature_size: int,
num_patches: int | None,
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
repl_features = self.IMG_CONTEXT * feature_size
repl_full = self.IMG_START + repl_features + self.IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
return PromptUpdateDetails.select_text(repl_full, self.IMG_CONTEXT)
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
@@ -396,7 +425,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
hf_config=config.get_text_config(),
prefix=maybe_prefix(prefix, "language_model"),
)
@@ -413,7 +442,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
# the awq models from OpenGVLab missing `modules_to_not_convert`
# patch the quant_config to add `modules_to_not_convert` back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
text_config = config.get_text_config()
llm_quant_config = getattr(text_config, "quantization_config", None)
if (not quant_config.modules_to_not_convert) and (
llm_quant_config is not None
@@ -429,10 +458,17 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
):
return AutoModel.from_config(config.vision_config, trust_remote_code=True)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
vit_hidden_size = config.vit_hidden_size
vision_projection_hidden_size = config.projector_hidden_size
llm_hidden_size = config.text_config.hidden_size
def _init_mlp1(
self,
config: PretrainedConfig,
vit_hidden_size: int | None = None,
vision_projection_hidden_size: int | None = None,
) -> nn.Module:
if vit_hidden_size is None:
vit_hidden_size = config.vit_hidden_size
if vision_projection_hidden_size is None:
vision_projection_hidden_size = config.projector_hidden_size
llm_hidden_size = config.get_text_config().hidden_size
return nn.Sequential(
nn.LayerNorm(
@@ -465,10 +501,18 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
x = x.permute(0, 2, 1, 3).contiguous()
return x
def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Call vision model and return embeddings.
Override this method in subclasses to handle different vision model
interfaces (e.g., SigLIP vs C-RADIO).
"""
vit_embeds = self.vision_model(x=pixel_values).features
return vit_embeds.to(dtype=torch.bfloat16)
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
# https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
vit_embeds = self.vision_model(x=pixel_values).features
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
vit_embeds = self._call_vision_model(pixel_values)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
@@ -523,15 +567,16 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
hidden_size = self.config.get_text_config().hidden_size
# Only one image in the current batch
if len(num_patches) == 1:
return (image_embeds.view(-1, self.config.text_config.hidden_size),)
return (image_embeds.view(-1, hidden_size),)
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
image_embeds = image_embeds.view(-1, hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in num_patches
]
@@ -643,3 +688,255 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
connector="mlp1",
tower_model="vision_model",
)
# --------------------------------------------------------
# LlamaNemotronVL Embedding Model (nvidia/llama-nemotron-embed-vl-1b-v2)
# Extends LlamaNemotronVLChatModel for embedding/pooling tasks:
# - SigLIP vision encoder (instead of C-RADIO)
# - Bidirectional (non-causal) LLaMA language model
# - Pooler output instead of generative logits
# --------------------------------------------------------
# SigLIP normalization constants
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)
def build_siglip_transform(input_size: int):
"""Build transform for SigLIP vision encoder with normalization.
Extends the base transform from nemotron_vl with SigLIP-specific normalization.
"""
base_transform = build_transform(input_size=input_size)
return T.Compose(
[
base_transform,
T.Normalize(mean=SIGLIP_MEAN, std=SIGLIP_STD),
]
)
class LlamaNemotronVLEmbedProcessor(NemotronVLProcessor):
"""
Processor for LlamaNemotronVL embedding model.
Inherits from NemotronVLProcessor and specializes it for embedding tasks:
- Uses SigLIP transform with normalization instead of base transform
- Uses different image context token (<IMG_CONTEXT> vs <image>)
"""
IMG_CONTEXT = "<IMG_CONTEXT>"
def __init__(
self,
config: PretrainedConfig,
tokenizer: TokenizerLike,
processor_config: dict,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
) -> None:
if min_dynamic_patch is None:
min_dynamic_patch = processor_config.get(
"min_input_tiles",
getattr(config, "min_dynamic_patch", 1),
)
if max_dynamic_patch is None:
max_dynamic_patch = processor_config.get(
"max_input_tiles",
getattr(config, "max_dynamic_patch", 1),
)
if dynamic_image_size is None:
dynamic_image_size = processor_config.get(
"dynamic_image_size",
getattr(config, "dynamic_image_size", True),
)
super().__init__(
config=config,
tokenizer=tokenizer,
image_processor=None,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
def _get_transform(self) -> T.Compose:
"""Override to add SigLIP normalization."""
return build_siglip_transform(input_size=self.image_size)
def _replace_image_tokens(
self,
text: list[str],
pixel_values_lst: list[torch.Tensor],
) -> list[str]:
"""Override with simpler token replacement for embedding model.
No temporary placeholder needed because IMG_CONTEXT is <IMG_CONTEXT>,
not <image>, so there's no collision risk.
"""
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
text = [t.replace("<image>", image_repl.full, 1) for t in text]
return text
class LlamaNemotronVLEmbedProcessingInfo(NemotronVLProcessingInfo):
"""Processing info for LlamaNemotronVL embedding model."""
def get_hf_processor(self, **kwargs: object) -> LlamaNemotronVLEmbedProcessor:
"""Override to create embedding-specific processor without image_processor."""
model_config = self.ctx.model_config
processor_config = {}
if model_config.model is not None:
processor_config = (
get_hf_file_to_dict(
"processor_config.json",
model_config.model,
model_config.revision,
)
or {}
)
return self.ctx.init_processor(
LlamaNemotronVLEmbedProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
processor_config=processor_config,
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
BaseInternVLMultiModalProcessor[LlamaNemotronVLEmbedProcessingInfo],
info=LlamaNemotronVLEmbedProcessingInfo,
dummy_inputs=BaseInternVLDummyInputsBuilder[LlamaNemotronVLEmbedProcessingInfo],
)
class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling):
"""
LlamaNemotronVL model for embeddings.
Inherits from LlamaNemotronVLChatModel and specializes it for embedding tasks:
- Uses SigLIP vision encoder instead of C-RADIO
- Uses bidirectional LLaMA (via llm_config) instead of causal LLaMA
- Adds pooler for embedding output instead of generating logits
"""
is_pooling_model = True
# Weight mapping from checkpoint format to vLLM format
# Different from parent class due to different vision model structure
weight_mapper = WeightsMapper(
orig_to_new_prefix={
# Language model mapping
"language_model.layers.": "language_model.model.layers.",
"language_model.embed_tokens.": "language_model.model.embed_tokens.",
"language_model.norm.": "language_model.model.norm.",
# Vision model mapping (SiglipVisionModel has nested vision_model)
"vision_model.encoder.": "vision_model.vision_model.encoder.",
"vision_model.embeddings.": "vision_model.vision_model.embeddings.",
"vision_model.post_layernorm.": "vision_model.vision_model.post_layernorm.",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
# Override: get img_context_token_id from config (parent sets None)
self.img_context_token_id = getattr(config, "img_context_token_id", None)
# Initialize pooler for embedding output
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_embedding(pooler_config)
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config,
*,
prefix: str,
) -> nn.Module:
"""Override to use SigLIP instead of C-RADIO."""
return SiglipVisionModel(
config.vision_config,
quant_config=quant_config,
prefix=prefix,
use_head=False,
)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
"""Override to use different MLP structure for embedding model."""
return super()._init_mlp1(
config,
vit_hidden_size=config.vision_config.hidden_size,
vision_projection_hidden_size=config.get_text_config().hidden_size,
)
def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Override to handle SigLIP interface."""
return self.vision_model(pixel_values)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Override to use different weight mapping for SigLIP."""
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weight_mapper)
class LlamaNemotronVLForSequenceClassification(
LlamaNemotronVLForEmbedding, SupportsCrossEncoding
):
"""LlamaNemotronVL model variant for sequence classification / reranking."""
# Reranker checkpoint places base model weights under `model.*`,
# while `score.*` remains at the top level.
weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
LlamaNemotronVLForEmbedding.weight_mapper
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
text_config = vllm_config.model_config.hf_config.get_text_config()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
self.score = ReplicatedLinear(
model_config.get_hidden_size(),
text_config.num_labels,
bias=False,
params_dtype=model_config.head_dtype,
quant_config=quant_config,
return_bias=False,
prefix=maybe_prefix(prefix, "score"),
)
pooler_config = model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_weights = super().load_weights(weights)
# reranker checkpoint omits the inner LM seq-cls head
# (`language_model.score.*`). It is unused by this outer model, but
# the default loader expects all parameters to be initialized.
for name, param in self.named_parameters():
if not name.startswith("language_model.score.") or name in loaded_weights:
continue
if name.endswith(".weight"):
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
elif name.endswith(".bias"):
torch.nn.init.zeros_(param)
else:
torch.nn.init.normal_(param, mean=0.0, std=0.02)
loaded_weights.add(name)
return loaded_weights