import contextlib
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPConfig
# NOTE: _make_causal_mask / _expand_mask are private helpers that are only exposed by
# modeling_clip in some transformers releases, so this file is tied to the version it
# was written against.
from transformers.models.clip.modeling_clip import (
    CLIPModel,
    CLIPTextTransformer,
    _make_causal_mask,
    _expand_mask,
    clip_loss,
    CLIPOutput,
)


class CLIPTextTransformerCanReceiveEmbed(CLIPTextTransformer):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_embeds: Optional[torch.Tensor] = None,  # NOTE: precomputed embeddings that bypass self.embeddings
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or input_embeds")
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
        else:
            # Skip the embedding layer and feed the provided embeddings directly.
            hidden_states = input_embeds
            input_shape = torch.Size([hidden_states.size(0), hidden_states.size(1)])

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        if input_ids is not None:
            eos_embedding_pos = input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1)
        else:
            # Without input_ids the EOS position is unknown; assume it is the last position.
            # TODO: is there any exception?
            eos_embedding_pos = torch.tensor(
                [input_embeds.size(1) - 1] * input_embeds.size(0), device=last_hidden_state.device
            )
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            eos_embedding_pos,
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
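

# Illustrative helper (not part of the original file): the `input_embeds` branch above
# bypasses `self.embeddings` entirely, so the tensor passed in must already contain
# token *and* position embeddings, i.e. it stands in for the output of the embedding
# layer rather than a plain token lookup. One convenient way to build such a tensor:
def build_text_input_embeds(
    text_model: CLIPTextTransformerCanReceiveEmbed,
    input_ids: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # CLIPTextEmbeddings adds the learned position embeddings internally.
    return text_model.embeddings(input_ids=input_ids, position_ids=position_ids)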


class CLIPModelCanReceiveTextEmbeds(CLIPModel):
    def __init__(self, config: CLIPConfig):
        super().__init__(config)
        # Replace the text tower with the variant that accepts precomputed embeddings.
        self.text_model = CLIPTextTransformerCanReceiveEmbed(config.text_config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        only_return_logits_per_text: bool = False,
        no_grad_text: bool = False,
    ) -> Union[Tuple, CLIPOutput]:
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Optionally run the text tower without tracking gradients.
        text_grad_ctx = torch.no_grad() if no_grad_text else contextlib.nullcontext()
        with text_grad_ctx:
            text_outputs = self.text_model(
                input_ids=input_ids,
                input_embeds=input_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        if only_return_logits_per_text:
            return logits_per_text

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
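

# Minimal smoke test (illustrative sketch, not part of the original file). It assumes
# the public "openai/clip-vit-base-patch32" checkpoint, a transformers version that
# still exposes _make_causal_mask/_expand_mask, and uses random tensors in place of
# preprocessed images; argument names follow the classes above.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    model = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    tokens = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
    pixel_values = torch.randn(2, 3, 224, 224)  # random stand-in for two preprocessed images

    # 1) Usual path: token ids are embedded inside the text tower.
    logits = model(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        pixel_values=pixel_values,
        only_return_logits_per_text=True,
    )

    # 2) Embedding path: precompute token + position embeddings and bypass the lookup
    #    (e.g. to optimize soft prompts). input_ids are still passed so the EOS
    #    position can be located for pooling.
    embeds = model.text_model.embeddings(input_ids=tokens.input_ids)
    logits_from_embeds = model(
        input_ids=tokens.input_ids,
        input_embeds=embeds,
        attention_mask=tokens.attention_mask,
        pixel_values=pixel_values,
        only_return_logits_per_text=True,
        no_grad_text=True,  # text tower runs under torch.no_grad()
    )
    print(logits.shape, logits_from_embeds.shape)  # both: (num_texts, num_images) = (2, 2)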