import os
from typing import List, Union

import timm
import torch
import torch.nn.functional as F
from torch import nn
from torch.jit import Final
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from timm.models._factory import load_checkpoint
from transformers import CLIPProcessor, CLIPModel, CLIPVisionConfig, CLIPConfig

from utils.dl.common.model import get_model_device, set_module
from utils.common.log import logger
from dnns.clip.custom_clip import CLIPModelCanReceiveTextEmbeds


class Clip_ViTB16(nn.Module):
    """CLIP ViT-B/16 wrapper that accepts a custom input image size."""

    def __init__(self, img_size):
        super(Clip_ViTB16, self).__init__()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        self.model: CLIPModel = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch16")
        self.img_size = img_size

        # rebuild the vision embeddings' patch/position bookkeeping so that
        # position_ids matches the number of patches for the new image size
        vm_embed = self.model.vision_model.embeddings
        raw_num_patches = vm_embed.num_patches
        vm_embed.num_patches = (img_size // vm_embed.patch_size) ** 2
        vm_embed.num_positions = vm_embed.num_patches + 1
        vm_embed.register_buffer("position_ids", torch.arange(vm_embed.num_positions).expand((1, -1)), persistent=False)
        logger.info(f'input image size changed to {img_size}; num_patches updated from {raw_num_patches} to {vm_embed.num_patches}')

        self.first_inference = True

    def forward(self, images, texts: Union[List[List[str]], torch.Tensor], for_training,
                disable_return_loss=False, only_return_logits_per_text=False, no_grad_text=False):
        if isinstance(texts[0], str):
            # raw strings: tokenize the texts and preprocess the images together
            inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        else:
            # texts is a tensor of precomputed text embeddings; the stock CLIP model
            # cannot take a tensor here, hence the custom CLIPModelCanReceiveTextEmbeds
            inputs = self.processor(images=images, return_tensors="pt")
            inputs['attention_mask'] = torch.ones((texts.size(0), texts.size(1)))
            inputs['input_embeds'] = texts

        inputs['return_loss'] = for_training and not disable_return_loss
        inputs['only_return_logits_per_text'] = only_return_logits_per_text
        inputs['no_grad_text'] = no_grad_text

        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to('cuda')

        if self.first_inference:
            logger.info(f'before input size: {inputs["pixel_values"].size()}')

        # the processor preprocesses at the pretrained resolution; resize the
        # pixel values again to the configured image size
        inputs['pixel_values'] = F.interpolate(inputs['pixel_values'], size=(self.img_size, self.img_size))

        if self.first_inference:
            logger.info(f'after input size: {inputs["pixel_values"].size()}')
            self.first_inference = False

        return self.model(**inputs)


# @torch.no_grad()
# def clip_vit_b_16():
#     # https://huggingface.co/openai/clip-vit-base-patch16
#     model = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch16")
#     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
#     print(model)
#
#     from PIL import Image
#     import requests
#     image = Image.open('/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/003.backpack/003_0001.jpg')
#     inputs = processor(text=["a photo of a dog", "a photo of a backpack", "a photo of a cat"], images=image, return_tensors="pt", padding=True)
#     print(inputs)
#
#     from utils.dl.common.model import LayerActivation2, get_module
#     input_embed_hook = LayerActivation2(get_module(model, 'text_model.embeddings'))
#     outputs = model(**inputs)
#     logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
#     probs = logits_per_image.softmax(dim=1)
#     print(probs)
#
#     input_embed = input_embed_hook.output
#     input_embed_hook.remove()
#     torch.save(input_embed, os.path.join(os.path.dirname(__file__), './test_input_embed.pth'))
#     print('embed', input_embed.size())
#
#     del inputs['input_ids']
#     inputs['input_embeds'] = input_embed
#     outputs = model(**inputs)
#     logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
#     probs = logits_per_image.softmax(dim=1)
#     print(probs)


def clip_vit_b_16(img_size):
    # https://huggingface.co/openai/clip-vit-base-patch16
    return Clip_ViTB16(img_size)
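
# Minimal usage sketch (kept commented out, following the style of the test code
# above), assuming a CUDA device and network access to download the pretrained
# weights; the 224x224 image size and the prompt strings are illustrative only.
#
# from PIL import Image
#
# model = clip_vit_b_16(224).cuda()
# image = Image.new('RGB', (224, 224))  # dummy image; replace with a real photo
# texts = ["a photo of a dog", "a photo of a cat"]
# out = model([image], texts, for_training=False)
# print(out.logits_per_image.softmax(dim=1))  # image-text similarity probabilities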


if __name__ == '__main__':
    # img_size=32 matches the Resize((32, 32)) transform used below
    model = clip_vit_b_16(32).cuda()
    # print(model)
    # exit()

    # config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch16')
    # print(config)

    # # test 1: single image inference
    # from PIL import Image
    # import requests
    # image = Image.open('/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/003.backpack/003_0001.jpg')
    # text = ["a photo of a dog", "a photo of a backpack", "a photo of a cat"]
    # o = model(image, text, False)
    # print(o)
    # print(o.logits_per_image.softmax(dim=1))
    # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False)
    # # print(o)
    # print(o.logits_per_image.softmax(dim=1))
    # exit()

    # test 2: normal training using clip loss (batch)
    from data import get_dataset, build_dataloader
    from torchvision.transforms import Compose, ToTensor, Resize

    dataset = get_dataset('Caltech256', '/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/', 'train', transform=Compose([
        Resize((32, 32)), ToTensor()
    ]))
    dataloader = build_dataloader(dataset, 8, 0, True, None)

    from PIL import Image
    import requests

    images, labels = next(iter(dataloader))
    # torch.save(images, 'dnns/clip/test_image.pth')

    classes = dataset.classes
    text = [f"a photo of a {classes[i]}" for i in labels]  # should be ground truth
    print(text)
    print(images.size())

    o = model(images, text, True)
    print(o)
    print(o.logits_per_image.softmax(dim=1))

    # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False)
    # # print(o)
    # print(o.logits_per_image.softmax(dim=1))
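
    # Sketch of one optimisation step on the CLIP contrastive loss, assuming the
    # custom CLIPModelCanReceiveTextEmbeds keeps the standard CLIPOutput interface
    # and populates `.loss` when return_loss=True; the optimizer and learning rate
    # below are illustrative choices, not part of the original training setup.
    if getattr(o, 'loss', None) is not None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        optimizer.zero_grad()
        o.loss.backward()
        optimizer.step()
        print(f'clip contrastive loss: {o.loss.item():.4f}')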