# UniWorld-V1/univa/models/modeling_univa_vision_tower.py
# LinB203
# init
# 0c8d55e
from univa.models.configuration_univa_vision_tower import UnivaVisionTowerConfig
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.modeling_utils import PreTrainedModel
import torch
import torch.nn as nn
class UnivaVisionTower(PreTrainedModel):
    """Vision backbone followed by one or two projection heads.

    The backbone (currently only SigLIP) produces per-layer hidden states;
    the layer chosen by ``config.feature_select_layer`` is fed through
    ``mm_projector`` and, when ``config.shortcut_projector_type`` is set,
    also through a shortcut branch.
    """

    # Configuration class associated with this model.
    config_class = UnivaVisionTowerConfig
    # Name of the attribute that holds the wrapped backbone (``self.model``).
    base_model_prefix = "model"

    def __init__(self, config: UnivaVisionTowerConfig):
        """Build the backbone and the projector(s) described by ``config``.

        Raises:
            ValueError: if ``config.vision_tower_type``,
                ``config.mm_projector_type`` or a non-``None``
                ``config.shortcut_projector_type`` is not recognised.
        """
        super().__init__(config)
        self.config = config

        # Only the SigLIP backbone is supported; reject anything else early.
        if config.vision_tower_type != "siglip":
            raise ValueError(f"Unknown vision tower type: {config.vision_tower_type}")

        tower_config = self.config.vision_tower_config
        self.model = SiglipVisionModel._from_config(tower_config)
        # Width of the backbone features; input size of every projector.
        self.mm_hidden_size = self.config.vision_tower_config.hidden_size

        self._init_mm_projector()
        # The shortcut branch is optional and skipped entirely when unset.
        if self.config.shortcut_projector_type is not None:
            self._init_shortcut_projector()

    @staticmethod
    def _build_mlp2x_gelu(in_features, out_features):
        """Return a Linear -> GELU -> Linear stack mapping in_features to out_features."""
        first = nn.Linear(in_features, out_features)
        activation = nn.GELU()
        second = nn.Linear(out_features, out_features)
        return nn.Sequential(first, activation, second)

    def _init_mm_projector(self):
        """Create ``self.mm_projector`` according to ``config.mm_projector_type``.

        Raises:
            ValueError: if the projector type is not ``"mlp2x_gelu"``.
        """
        if self.config.mm_projector_type != "mlp2x_gelu":
            raise ValueError(
                f"Unknown mm_projector_type: {self.config.mm_projector_type}"
            )
        self.mm_projector = self._build_mlp2x_gelu(
            self.mm_hidden_size, self.config.output_hidden_size
        )

    def _init_shortcut_projector(self):
        """Create the shortcut branch selected by ``config.shortcut_projector_type``.

        ``"mlp2x_gelu"`` builds a separate ``self.shortcut_projector``;
        ``"share_mm_projector"`` builds nothing, because ``forward`` then
        returns the ``mm_projector`` output for both branches.

        Raises:
            ValueError: for any other projector type.
        """
        shortcut_type = self.config.shortcut_projector_type
        if shortcut_type == "share_mm_projector":
            # Nothing to allocate: the main projector doubles as the shortcut.
            return
        if shortcut_type != "mlp2x_gelu":
            raise ValueError(
                f"Unknown shortcut_projector_type: {self.config.shortcut_projector_type}"
            )
        self.shortcut_projector = self._build_mlp2x_gelu(
            self.mm_hidden_size, self.config.output_hidden_size
        )

    def forward(self, pixel_values: torch.Tensor):
        """Encode ``pixel_values`` and project the selected hidden state.

        Args:
            pixel_values: image tensor passed straight to the backbone.

        Returns:
            A 2-tuple ``(projected, shortcut)``. ``projected`` is the
            ``mm_projector`` output. ``shortcut`` is ``None`` when no shortcut
            projector is configured, the very same tensor as ``projected``
            when the type is ``"share_mm_projector"``, and otherwise the
            output of ``shortcut_projector``.
        """
        backbone_out = self.model(pixel_values, output_hidden_states=True)
        features = backbone_out.hidden_states[self.config.feature_select_layer]

        shortcut_type = self.config.shortcut_projector_type
        shares_projector = shortcut_type == "share_mm_projector"

        # A dedicated shortcut projector exists only for a configured,
        # non-shared type; it is evaluated before the main projector.
        shortcut = None
        if shortcut_type is not None and not shares_projector:
            shortcut = self.shortcut_projector(features)

        projected = self.mm_projector(features)

        if shares_projector:
            # Share the mm_projector as the shortcut_projector.
            return projected, projected
        return projected, shortcut