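"""UniVA multimodal wrapper around Qwen2: couples a Qwen2 language model with a
vision tower (image encoder) and a denoise tower (projector + denoiser)."""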
from typing import Optional, List, Tuple, Union, Literal, Dict

import torch
import torch.nn as nn
from transformers import (
    Qwen2Model,
    Qwen2PreTrainedModel,
    GenerationMixin,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from univa.models.modeling_univa_vision_tower import UnivaVisionTower
from univa.models.configuration_univa import UnivaConfig
from univa.models.modeling_univa_denoise_tower import UnivaDenoiseTower


class UnivaQwen2Model(Qwen2Model):
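    """Qwen2 backbone configured with UnivaConfig; used inside UnivaQwen2ForCausalLM."""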
    def __init__(self, config: UnivaConfig):
        super().__init__(config)
        self.config = config


class UnivaQwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
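    """Causal LM head over UnivaQwen2Model, extended with a vision tower that embeds
    images into the token stream and a denoise tower that turns LM hidden states into
    conditioning for the denoiser (see `forward(output_type=...)`)."""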
    config_class = UnivaConfig

    def __init__(self, config: UnivaConfig):
        super().__init__(config)
        self.model = UnivaQwen2Model(config)
        self.vision_tower = UnivaVisionTower(config.vision_tower)
        self.denoise_tower = UnivaDenoiseTower(config.denoise_tower)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.forward_denoiser = False

        # Initialize weights and apply final processing
        self.post_init()
    def get_denoise_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[List[torch.FloatTensor]] = None,
        image_position: Optional[torch.LongTensor] = None,
    ):
        # Pass image_position by keyword so it is not bound to `image_embeds`.
        input_embeds = self(
            input_ids, pixel_values=pixel_values, image_position=image_position
        )[0]
        input_embeds = self.denoise_tower(input_embeds)
        return input_embeds
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        image_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_type: Literal["lvlm", "denoise_model_pred", "denoise_embeds"] = "lvlm",
        only_use_t5: bool = False,
        denoiser_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
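        """Forward the LVLM and return one of three things, selected by `output_type`:

        - "lvlm": the usual causal-LM output (logits and optional loss).
        - "denoise_embeds": last hidden states mapped through
          `denoise_tower.denoise_projector` (prompt embeddings for the denoiser).
        - "denoise_model_pred": the denoise tower's prediction computed from the last
          hidden states and `denoiser_kwargs` (which must then be non-empty).

        `only_use_t5=True` skips the LVLM pass, so the denoise tower receives
        `encoder_hidden_states=None`. A typical call for prompt embeddings looks like
        `model(input_ids, pixel_values=pixel_values, image_position=image_position,
        output_type="denoise_embeds")`.
        """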
        # Replace the None default so callers may omit denoiser_kwargs entirely.
        if denoiser_kwargs is None:
            denoiser_kwargs = {}
        if not only_use_t5:
            if (
                self.forward_denoiser
            ):  # Force-forward the denoiser; used in FSDP training.
                return self.denoise_tower.denoiser(**kwargs)
            if "hidden_states" in kwargs:
                print(
                    "You are using this model as a denoiser; please use "
                    "forward_denoiser_context to forward the model."
                )
                print("For example:")
                print("with self.forward_denoiser_context():")
                print("    ... # Your code ...")
            inputs_embeds, shortcut_image_embeds = self.prepare_inputs_for_multimodal(
                input_ids,
                pixel_values,
                image_position,
                past_key_values,
                output_image_embeds=True,
            )
if output_type == "denoise_model_pred": | |
assert len(denoiser_kwargs) > 0, ( | |
"denoiser_kwargs should not be empty when output_type is denoise_model_pred" | |
) | |
return_dict = False | |
outputs = self.inner_forward( | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
past_key_values=past_key_values, | |
inputs_embeds=inputs_embeds, | |
labels=labels, | |
use_cache=use_cache, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
cache_position=cache_position, | |
output_denoise_embeds=output_type.startswith("denoise"), | |
**kwargs, | |
) | |
        else:
            outputs = None

        if output_type.startswith("denoise"):
            if (
                outputs is not None
                and shortcut_image_embeds is not None
                and self.config.shortcut_image_embeds
            ):
                for (
                    batch_idx,
                    pos,
                    image_seq_length,
                    image_embeds_item,
                ) in shortcut_image_embeds:
                    outputs[batch_idx, pos : pos + image_seq_length, :] = (
                        self.config.shortcut_image_embeds_scale * image_embeds_item
                        + (1 - self.config.shortcut_image_embeds_scale)
                        * outputs[batch_idx, pos : pos + image_seq_length, :]
                    )
            if output_type == "denoise_embeds":
                # LVLM outputs -> MLP2 -> prompt_embeds
                # With prompt_embeds, we can directly forward the denoiser.
                return self.denoise_tower.denoise_projector(outputs)
            elif output_type == "denoise_model_pred":
                # LM outputs -> MLP2 -> Denoiser -> model_pred
                return self.denoise_tower(
                    encoder_hidden_states=outputs, **denoiser_kwargs
                )
            else:
                raise ValueError(f"Unknown output_type: {output_type}.")
        return outputs

    def prepare_inputs_for_multimodal(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[List[torch.FloatTensor]] = None,
        image_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        output_image_embeds: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[int, int, int, torch.Tensor]]]]:
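        """Embed `input_ids` and scatter vision-tower features into the positions of
        `config.image_token_id`. Returns `(input_embeds, shortcut_image_embeds)`, where
        the second element is either `None` or a list of
        `(batch_idx, start_pos, length, embeds)` tuples describing each image span,
        used later to blend shortcut image embeddings back into the LM outputs."""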
        batch_size, _ = input_ids.shape
        input_embeds = self.model.embed_tokens(input_ids)
        if (
            past_key_values is not None and len(past_key_values.key_cache) > 0
        ):  # Skip if using cache
            return input_embeds, None
        if pixel_values is None:  # No image input
            return input_embeds, None

        image_embeds, shortcut_image_embeds_batch = self.vision_tower(pixel_values)
        image_embeds = image_embeds.reshape(-1, image_embeds.shape[-1])
        if shortcut_image_embeds_batch is not None:
            shortcut_image_embeds_batch = shortcut_image_embeds_batch.reshape(
                -1, image_embeds.shape[-1]
            )

        n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: "
                f"tokens: {n_image_tokens}, features {n_image_features}"
            )

        image_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(input_embeds)
            .to(input_embeds.device)
        )
        image_embeds = image_embeds.to(input_embeds.device, input_embeds.dtype)
        input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)

        shortcut_image_embeds = []
        if pixel_values is not None and shortcut_image_embeds_batch is not None:
            cum_image_len = 0
            for batch_idx in range(input_ids.shape[0]):
                cur_input_ids = input_ids[batch_idx]
                num_blocks, start_end_index, lengths = self.find_true_blocks(
                    cur_input_ids == self.config.image_token_id
                )
                for i in range(len(num_blocks)):
                    # (batch_idx, pos, length, shortcut embeds for this image block)
                    shortcut_image_embeds.append(
                        (
                            batch_idx,
                            start_end_index[i],
                            lengths[i],
                            shortcut_image_embeds_batch[
                                cum_image_len : cum_image_len + lengths[i]
                            ],
                        )
                    )
                    cum_image_len = cum_image_len + lengths[i]
        if output_image_embeds:
            return input_embeds, shortcut_image_embeds
        else:
            return input_embeds, None
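    # A per-image variant of prepare_inputs_for_multimodal (driven by image_position and
    # explicit image begin/end tokens) is kept below, commented out.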
    # def prepare_inputs_for_multimodal(
    #     self,
    #     input_ids: torch.LongTensor,
    #     pixel_values: Optional[List[torch.FloatTensor]] = None,
    #     image_position: Optional[torch.LongTensor] = None,
    #     past_key_values: Optional[List[torch.FloatTensor]] = None,
    #     output_image_embeds: Optional[bool] = False,
    # ) -> Tuple[torch.Tensor, Optional[List[Tuple[int, int, int, torch.Tensor]]]]:
    #     batch_size, _ = input_ids.shape
    #     input_embeds = self.model.embed_tokens(input_ids)
    #     if (
    #         past_key_values is not None and len(past_key_values.key_cache) > 0
    #     ):  # Skip if using cache
    #         return input_embeds, None
    #     if pixel_values is None:  # No image input
    #         return input_embeds, None
    #     shortcut_image_embeds = []
    #     for batch_idx in range(batch_size):
    #         images_batch = pixel_values[batch_idx]
    #         if len(images_batch) == 0:
    #             continue
    #         input_images = torch.stack(images_batch)
    #         image_embeds, shortcut_image_embeds_batch = self.vision_tower(input_images)
    #         for image_idx, pos in enumerate(image_position[batch_idx]):
    #             image_embeds_item = image_embeds[image_idx]
    #             image_seq_length, _ = image_embeds_item.shape
    #             assert (
    #                 input_ids[batch_idx, pos]
    #                 == input_ids[batch_idx, pos + image_seq_length - 1]
    #             ), "image token is not correct"
    #             assert input_ids[batch_idx, pos - 1] == 151666, (
    #                 "image begin token is not correct"
    #             )
    #             assert input_ids[batch_idx, pos + image_seq_length] == 151667, (
    #                 "image end token is not correct"
    #             )
    #             input_embeds[batch_idx, pos : pos + image_seq_length, :] = (
    #                 image_embeds_item
    #             )
    #             if shortcut_image_embeds_batch is not None:
    #                 shortcut_image_embeds.append(
    #                     (
    #                         batch_idx,
    #                         pos,
    #                         image_seq_length,
    #                         shortcut_image_embeds_batch[image_idx],
    #                     )
    #                 )
    #     if output_image_embeds:
    #         return input_embeds, shortcut_image_embeds
    #     else:
    #         return input_embeds, None

    def inner_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_denoise_embeds: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
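        """Standard Qwen2 causal-LM forward over precomputed `inputs_embeds`. When
        `output_denoise_embeds=True`, the last hidden states are returned directly
        instead of computing logits and loss."""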
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        if output_denoise_embeds:
            return hidden_states

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def forward_denoiser_context(self):
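        """Context manager that temporarily exposes the denoiser: it swaps `self.config`
        for the denoiser's config and sets `forward_denoiser=True`, so that `forward`
        dispatches `**kwargs` straight to `denoise_tower.denoiser` (used when the
        denoiser is wrapped for FSDP training)."""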
        class ForwardDenoiserContext:
            def __init__(self, model):
                self.model = model
                self.backup_config = None

            def __enter__(self):
                self.backup_config = self.model.config
                self.model.config = self.model.denoise_tower.denoiser.config
                self.model.forward_denoiser = True
                return self.model

            def __exit__(self, exc_type, exc_val, exc_tb):
                self.model.forward_denoiser = False
                self.model.config = self.backup_config
                return False

        return ForwardDenoiserContext(self)