# UniWorld-V1/univa/models/modeling_univa.py
from typing import Optional, List, Tuple, Union, Literal, Dict
import torch
import torch.nn as nn
from transformers import (
    Cache,
    GenerationMixin,
    Qwen2Model,
    Qwen2PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from univa.models.modeling_univa_vision_tower import UnivaVisionTower
from univa.models.configuration_univa import UnivaConfig
from univa.models.modeling_univa_denoise_tower import UnivaDenoiseTower
class UnivaQwen2Model(Qwen2Model):
def __init__(self, config: UnivaConfig):
super().__init__(config)
self.config = config
class UnivaQwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
config_class = UnivaConfig
def __init__(self, config: UnivaConfig):
super().__init__(config)
self.model = UnivaQwen2Model(config)
self.vision_tower = UnivaVisionTower(config.vision_tower)
self.denoise_tower = UnivaDenoiseTower(config.denoise_tower)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # When True, forward() routes directly to the denoiser (FSDP path).
        self.forward_denoiser = False
# Initialize weights and apply final processing
self.post_init()
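    # Minimal instantiation sketch (the checkpoint path is hypothetical):
    # __init__ wires the language model, vision tower, and denoise tower
    # together, so a single composite checkpoint drives every output_type.
    #
    #     model = UnivaQwen2ForCausalLM.from_pretrained("path/to/univa-checkpoint")
    #     out = model(input_ids, pixel_values=pixel_values)  # default "lvlm" mode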
    def get_denoise_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[List[torch.FloatTensor]] = None,
        image_position: Optional[torch.LongTensor] = None,
    ):
        # Pass image arguments by keyword: forward()'s third positional
        # parameter is `image_embeds`, so a positional `image_position`
        # would bind to the wrong argument.
        input_embeds = self(
            input_ids, pixel_values=pixel_values, image_position=image_position
        )[0]
        input_embeds = self.denoise_tower(input_embeds)
        return input_embeds
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: Optional[List[torch.FloatTensor]] = None,
image_embeds: Optional[torch.FloatTensor] = None,
image_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_type: Literal["lvlm", "denoise_model_pred", "denoise_embeds"] = "lvlm",
only_use_t5: bool = False,
        denoiser_kwargs: Optional[Dict] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
        if denoiser_kwargs is None:
            denoiser_kwargs = {}
        if not only_use_t5:
            if self.forward_denoiser:
                # Forced denoiser forward, used when FSDP wraps this module
                # during training: dispatch straight to the denoiser.
                return self.denoise_tower.denoiser(**kwargs)
            if "hidden_states" in kwargs:
                print(
                    "You are using this model as a denoiser; wrap the call in "
                    "forward_denoiser_context(), for example:"
                )
                print("with self.forward_denoiser_context():")
                print("    ...  # your code")
inputs_embeds, shortcut_image_embeds = self.prepare_inputs_for_multimodal(
input_ids,
pixel_values,
image_position,
past_key_values,
output_image_embeds=True,
)
            if output_type == "denoise_model_pred":
                assert denoiser_kwargs, (
                    "denoiser_kwargs must be non-empty when output_type is "
                    "'denoise_model_pred'"
                )
return_dict = False
outputs = self.inner_forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
output_denoise_embeds=output_type.startswith("denoise"),
**kwargs,
)
        else:
            outputs = None
            shortcut_image_embeds = None
        if output_type.startswith("denoise"):
            if (
                outputs is not None
                and shortcut_image_embeds is not None
                and self.config.shortcut_image_embeds
            ):
                # Blend the vision tower's shortcut embeddings back into the
                # LLM hidden states at each image span, as a convex combination
                # controlled by shortcut_image_embeds_scale.
                for (
                    batch_idx,
                    pos,
                    image_seq_length,
                    image_embeds_item,
                ) in shortcut_image_embeds:
                    outputs[batch_idx, pos : pos + image_seq_length, :] = (
                        self.config.shortcut_image_embeds_scale * image_embeds_item
                        + (1 - self.config.shortcut_image_embeds_scale)
                        * outputs[batch_idx, pos : pos + image_seq_length, :]
                    )
            if output_type == "denoise_embeds":
                # LVLM hidden states -> MLP2 -> prompt_embeds; with these, the
                # caller can forward the denoiser directly.
                return self.denoise_tower.denoise_projector(outputs)
            elif output_type == "denoise_model_pred":
                # LVLM hidden states -> MLP2 -> Denoiser -> model_pred.
                return self.denoise_tower(
                    encoder_hidden_states=outputs, **denoiser_kwargs
                )
else:
raise ValueError(f"Unknown output_type: {output_type}.")
return outputs
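    # Usage sketch for the three output_type modes (tensor names and the
    # denoiser_kwargs keys are illustrative, not fixed by this file):
    #
    #     # 1) "lvlm": ordinary causal-LM forward.
    #     out = model(input_ids, pixel_values=pixels, labels=labels)
    #
    #     # 2) "denoise_embeds": LVLM hidden states -> denoise_projector.
    #     prompt_embeds = model(input_ids, pixel_values=pixels,
    #                           output_type="denoise_embeds")
    #
    #     # 3) "denoise_model_pred": additionally run the denoiser; extra
    #     #    inputs (e.g. noisy latents, timesteps) ride in denoiser_kwargs.
    #     pred = model(input_ids, pixel_values=pixels,
    #                  output_type="denoise_model_pred",
    #                  denoiser_kwargs={"hidden_states": noisy_latents,
    #                                   "timestep": timesteps})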
    def prepare_inputs_for_multimodal(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[List[torch.FloatTensor]] = None,
        image_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_image_embeds: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[int, int, int, torch.Tensor]]]]:
batch_size, _ = input_ids.shape
input_embeds = self.model.embed_tokens(input_ids)
        if past_key_values is not None and len(past_key_values.key_cache) > 0:
            # Decoding with a warm KV cache: image embeddings were already
            # merged during prefill, so plain token embeddings suffice.
            return input_embeds, None
if pixel_values is None: # No image input
return input_embeds, None
        # Encode all images and flatten to (num_image_tokens, hidden_size).
        image_embeds, shortcut_image_embeds_batch = self.vision_tower(pixel_values)
        image_embeds = image_embeds.reshape(-1, image_embeds.shape[-1])
        if shortcut_image_embeds_batch is not None:
            shortcut_image_embeds_batch = shortcut_image_embeds_batch.reshape(
                -1, image_embeds.shape[-1]
            )
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
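        # The image-token positions in input_embeds are overwritten below via
        # masked_scatter, which fills True positions in flattened order; a tiny
        # standalone illustration (not part of the model):
        #
        #     emb = torch.zeros(1, 4, 2)
        #     mask = torch.tensor([[False, True, True, False]])
        #     emb.masked_scatter(mask.unsqueeze(-1).expand_as(emb),
        #                        torch.ones(2, 2))  # rows 1 and 2 become ones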
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(input_embeds)
.to(input_embeds.device)
)
image_embeds = image_embeds.to(input_embeds.device, input_embeds.dtype)
input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
        shortcut_image_embeds = []
        if shortcut_image_embeds_batch is not None:
            # Record (batch_idx, start, length, embeds) for every contiguous
            # run of image tokens, so the caller can blend the shortcut
            # embeddings back into the LLM outputs.
            cum_image_len = 0
            for batch_idx in range(batch_size):
                cur_input_ids = input_ids[batch_idx]
                num_blocks, start_index, lengths = self.find_true_blocks(
                    cur_input_ids == self.config.image_token_id
                )
                for i in range(len(num_blocks)):
                    shortcut_image_embeds.append(
                        (
                            batch_idx,
                            start_index[i],
                            lengths[i],
                            shortcut_image_embeds_batch[
                                cum_image_len : cum_image_len + lengths[i]
                            ],
                        )
                    )
                    cum_image_len += lengths[i]
        if output_image_embeds:
            return input_embeds, shortcut_image_embeds
        else:
            return input_embeds, None
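    # `find_true_blocks` is not defined in this file. The contract inferred
    # from the call site above: given a 1-D boolean mask, return parallel
    # per-block sequences of (block ids, start indices, lengths), one entry
    # per contiguous run of True values. A minimal sketch under that
    # assumption:
    #
    #     @staticmethod
    #     def find_true_blocks(mask: torch.Tensor):
    #         # Pad with False on both sides so every block has a rising and
    #         # a falling edge, then diff to locate the edges.
    #         padded = torch.cat([mask.new_zeros(1), mask, mask.new_zeros(1)])
    #         edges = padded.int().diff()
    #         starts = (edges == 1).nonzero(as_tuple=True)[0]
    #         ends = (edges == -1).nonzero(as_tuple=True)[0]
    #         lengths = ends - starts
    #         return torch.arange(len(starts)), starts, lengths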
def inner_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_denoise_embeds: Optional[bool] = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
        if output_denoise_embeds:
            # Denoise path: return last-layer hidden states directly, without
            # projecting through the LM head.
            return hidden_states
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
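    # Note: the denoise paths call this method with output_denoise_embeds=True,
    # in which case the last-layer hidden states are returned as-is (no lm_head
    # projection, no loss). Sketch:
    #
    #     hidden = self.inner_forward(inputs_embeds=embeds,
    #                                 output_denoise_embeds=True)
    #     prompt_embeds = self.denoise_tower.denoise_projector(hidden)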
    def forward_denoiser_context(self):
        """Temporarily route forward() to the denoiser and expose its config.

        Used so FSDP-wrapped training steps can call the denoiser through
        this top-level module (see `forward_denoiser` in `forward`).
        """
class ForwardDenoiserContext:
def __init__(self, model):
self.model = model
self.backup_config = None
def __enter__(self):
self.backup_config = self.model.config
self.model.config = self.model.denoise_tower.denoiser.config
self.model.forward_denoiser = True
return self.model
def __exit__(self, exc_type, exc_val, exc_tb):
self.model.forward_denoiser = False
self.model.config = self.backup_config
return False
return ForwardDenoiserContext(self)
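# Usage sketch for FSDP training (the denoiser argument names below are
# illustrative): inside the context, the config is swapped to the denoiser's
# and forward() dispatches straight to denoise_tower.denoiser(**kwargs).
#
#     with model.forward_denoiser_context():
#         model_pred = model(hidden_states=noisy_latents,
#                            timestep=timesteps,
#                            encoder_hidden_states=prompt_embeds)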