from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class ForgeMicroVisionConfig:
    """Hyperparameters for ``ForgeMicroVisionAdapter``."""

    # Nominal input resolution; not read by ForgeMicroVisionAdapter itself.
    image_size: int = 224
    # Side length of the square, non-overlapping patches (Conv2d kernel == stride).
    patch_size: int = 16
    # Channel width of the patch embeddings; also the LayerNorm feature size.
    vision_width: int = 128
    # Size of each output token; presumably the target text model's hidden size — confirm.
    text_hidden_size: int = 1024
    # Number of prefix tokens each image is pooled down to.
    num_prefix_tokens: int = 16


class ForgeMicroVisionAdapter(nn.Module):
    """Tiny untrained image-to-prefix adapter scaffold for Forge-1V experiments.

    Pipeline: patchify with a strided, bias-free Conv2d, adaptively average-pool
    the patch sequence to ``num_prefix_tokens`` tokens, apply LayerNorm, then
    linearly project each token to ``text_hidden_size``.
    """

    def __init__(self, config: Optional[ForgeMicroVisionConfig] = None):
        """Build the adapter layers from ``config``.

        Args:
            config: Adapter hyperparameters. ``None`` (the default) creates a
                fresh ``ForgeMicroVisionConfig()`` for this instance, so separate
                adapters never share one mutable default config object.
        """
        super().__init__()
        if config is None:
            config = ForgeMicroVisionConfig()
        # Layers below are sized from the config once; mutating self.config
        # afterwards does not resize them.
        self.config = config
        self.patch_embed = nn.Conv2d(
            3,
            config.vision_width,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.pool = nn.AdaptiveAvgPool1d(config.num_prefix_tokens)
        self.norm = nn.LayerNorm(config.vision_width)
        self.proj = nn.Linear(config.vision_width, config.text_hidden_size, bias=False)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to prefix-token embeddings.

        Args:
            images: Tensor of shape ``(batch, 3, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_prefix_tokens, text_hidden_size)``.

        Raises:
            ValueError: If ``images`` is not 4-D. Unbatched input would otherwise
                have its axes scrambled by the flatten/pool/transpose below.
        """
        if images.dim() != 4:
            raise ValueError(
                "expected images of shape (batch, 3, height, width), "
                f"got {tuple(images.shape)}"
            )
        # (B, vision_width, height // patch_size, width // patch_size)
        patches = self.patch_embed(images)
        # (B, vision_width, num_patches)
        tokens = patches.flatten(2)
        # Pool along the patch axis, then move tokens ahead of channels:
        # (B, num_prefix_tokens, vision_width)
        tokens = self.pool(tokens).transpose(1, 2)
        tokens = self.norm(tokens)
        return self.proj(tokens)