66 lines
2.5 KiB
Python
66 lines
2.5 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
"""Configuration for ColModernVBERT visual document retrieval model.
|
||
|
|
|
||
|
|
ColModernVBERT combines SigLIP vision encoder + ModernBERT text encoder
|
||
|
|
with a pixel shuffle connector and ColBERT-style 128-dim per-token embeddings.
|
||
|
|
|
||
|
|
Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
|
||
|
|
"""
|
||
|
|
|
||
|
|
from transformers import ModernBertConfig, PretrainedConfig, SiglipVisionConfig
|
||
|
|
|
||
|
|
|
||
|
|
class ColModernVBertConfig(PretrainedConfig):
|
||
|
|
model_type = "colmodernvbert"
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
embedding_dim: int = 128,
|
||
|
|
vlm_config: dict | None = None,
|
||
|
|
**kwargs,
|
||
|
|
):
|
||
|
|
super().__init__(**kwargs)
|
||
|
|
self.embedding_dim = embedding_dim
|
||
|
|
|
||
|
|
if vlm_config is None:
|
||
|
|
vlm_config = {}
|
||
|
|
|
||
|
|
# Top-level VLM fields
|
||
|
|
self.image_token_id = vlm_config.get("image_token_id", 50407)
|
||
|
|
self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4)
|
||
|
|
self.hidden_size = vlm_config.get("hidden_size", 768)
|
||
|
|
additional_vocab_size = vlm_config.get("additional_vocab_size", 40)
|
||
|
|
|
||
|
|
# Text config (ModernBERT)
|
||
|
|
text_cfg = vlm_config.get("text_config", {})
|
||
|
|
base_vocab = text_cfg.get("vocab_size", 50368)
|
||
|
|
self.text_config = ModernBertConfig(
|
||
|
|
vocab_size=base_vocab + additional_vocab_size,
|
||
|
|
hidden_size=text_cfg.get("hidden_size", 768),
|
||
|
|
intermediate_size=text_cfg.get("intermediate_size", 1152),
|
||
|
|
num_hidden_layers=text_cfg.get("num_hidden_layers", 22),
|
||
|
|
num_attention_heads=text_cfg.get("num_attention_heads", 12),
|
||
|
|
mlp_bias=text_cfg.get("mlp_bias", False),
|
||
|
|
max_position_embeddings=vlm_config.get("max_position_embeddings", 8192),
|
||
|
|
)
|
||
|
|
|
||
|
|
# Vision config (SigLIP)
|
||
|
|
vis_cfg = vlm_config.get("vision_config", {})
|
||
|
|
self.vision_config = SiglipVisionConfig(
|
||
|
|
hidden_size=vis_cfg.get("embed_dim", 768),
|
||
|
|
image_size=vis_cfg.get("image_size", 512),
|
||
|
|
patch_size=vis_cfg.get("patch_size", 16),
|
||
|
|
num_hidden_layers=vis_cfg.get("num_hidden_layers", 12),
|
||
|
|
intermediate_size=vis_cfg.get("intermediate_size", 3072),
|
||
|
|
num_attention_heads=vis_cfg.get("num_attention_heads", 12),
|
||
|
|
)
|
||
|
|
|
||
|
|
@property
|
||
|
|
def image_seq_len(self) -> int:
|
||
|
|
ps = self.vision_config.image_size // self.vision_config.patch_size
|
||
|
|
return (ps * ps) // (self.pixel_shuffle_factor**2)
|
||
|
|
|
||
|
|
def get_text_config(self, **kwargs):
|
||
|
|
return self.text_config
|