39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
@dataclass()
class ForgeMicroVisionConfig:
    """Hyper-parameters for ``ForgeMicroVisionAdapter``.

    Every field has a default, so ``ForgeMicroVisionConfig()`` gives the stock
    scaffold configuration. Fields may be overridden by keyword, or
    positionally in the declared order below.
    """

    # Nominal input resolution in pixels. The adapter in this file never reads
    # it (its patch convolution takes whatever spatial size it is given);
    # presumably kept for callers / preprocessing — confirm.
    image_size: int = 224
    # Side length of a square patch: used as both the kernel size and the
    # stride of the adapter's patch-embedding convolution.
    patch_size: int = 16
    # Channels produced by the patch embedding; also the LayerNorm width and
    # the projection's input size.
    vision_width: int = 128
    # Output width of each prefix embedding (projection output size);
    # presumably must match the consuming text model's hidden size — verify.
    text_hidden_size: int = 1024
    # Length the patch sequence is adaptively pooled down to.
    num_prefix_tokens: int = 16
|
|
|
|
|
|
class ForgeMicroVisionAdapter(nn.Module):
    """Tiny untrained image-to-prefix adapter scaffold for Forge-1V experiments.

    Maps a batch of RGB images to a fixed-length sequence of embeddings sized
    for the text model:

    1. ``patch_embed``: a bias-free ``Conv2d`` with kernel == stride ==
       ``patch_size`` turns 3 input channels into ``vision_width`` channels,
       one output position per non-overlapping patch.
    2. ``pool``: the flattened patch grid is adaptively average-pooled to
       exactly ``num_prefix_tokens`` positions.
    3. ``norm`` / ``proj``: each pooled token is layer-normalised and linearly
       projected (no bias) to ``text_hidden_size``.

    Parameters use PyTorch's default initialisation; nothing here loads
    pretrained weights.
    """

    def __init__(self, config: "ForgeMicroVisionConfig | None" = None):
        """Create the adapter layers.

        Args:
            config: Adapter hyper-parameters. When omitted or ``None``, a new
                ``ForgeMicroVisionConfig()`` with default values is created for
                this instance.
        """
        super().__init__()
        # Build the default inside the call. A ``= ForgeMicroVisionConfig()``
        # default is evaluated once at definition time, so every adapter built
        # without a config would share (and could mutate) the same object.
        if config is None:
            config = ForgeMicroVisionConfig()
        self.config = config

        self.patch_embed = nn.Conv2d(
            3,
            config.vision_width,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.pool = nn.AdaptiveAvgPool1d(config.num_prefix_tokens)
        self.norm = nn.LayerNorm(config.vision_width)
        self.proj = nn.Linear(config.vision_width, config.text_hidden_size, bias=False)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` into prefix embeddings.

        Args:
            images: Batched image tensor, assumed ``(batch, 3, height, width)``.
                The 3 channels are fixed by ``patch_embed``, and the
                flatten/transpose below require a leading batch dim. Rows and
                columns beyond the last full patch are dropped by the strided
                convolution.

        Returns:
            Tensor of shape ``(batch, num_prefix_tokens, text_hidden_size)``.
        """
        # (B, 3, H, W) -> (B, vision_width, H // patch_size, W // patch_size)
        patches = self.patch_embed(images)
        # Collapse the patch grid into a sequence: (B, vision_width, N).
        tokens = patches.flatten(2)
        # Pool along the sequence axis to a fixed length, then move channels
        # last: (B, num_prefix_tokens, vision_width).
        tokens = self.pool(tokens).transpose(1, 2)
        # LayerNorm and Linear both act on the last (channel) dimension.
        tokens = self.norm(tokens)
        return self.proj(tokens)
|