# Acknowledgments:
# This project is inspired by:
# 1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
# 2. OpenAI's CLIP



# Import all the necessary libraries
import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO

from PIL import Image as PILIMAGE

from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sentence_transformers import SentenceTransformer, util



device = "cuda" if torch.cuda.is_available() else "cpu"

# Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Load data
photos = pd.read_csv("./photos.tsv000", sep='\t', header=0)
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])
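
# Note (assumption): features.npy is expected to contain the precomputed,
# L2-normalized CLIP image embeddings of the Unsplash photos, with rows aligned
# to the IDs in photo_ids.csv. The dot products computed below are only valid
# cosine similarities if that normalization holds.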



def encode_text(text):
  with torch.no_grad():
    # Encode and normalize the description using CLIP
    inputs = processor(text=[text], images=None, return_tensors="pt", padding=True).to(device)
    text_encoded = model.get_text_features(**inputs)
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
  return text_encoded.cpu().numpy()

def encode_image(image):
  # Convert the numpy array coming from the Gradio image component into a PIL image
  image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
  with torch.no_grad():
    # Preprocess, encode and normalize the image using CLIP
    photo_preprocessed = processor(text=None, images=image, return_tensors="pt", padding=True)["pixel_values"]
    search_photo_feature = model.get_image_features(photo_preprocessed.to(device))
    search_photo_feature /= search_photo_feature.norm(dim=-1, keepdim=True)
  return search_photo_feature.cpu().numpy()

T2I = "Text2Image"
I2I = "Image2Image"

def similarity(feature, photo_features):
  # Dot product of the (normalized) query feature with every photo feature
  similarities = list((feature @ photo_features.T).squeeze(0))
  return similarities
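
# Usage sketch (illustrative only, not executed): a quick way to sanity-check
# the text-to-image retrieval path, e.g. in a Python shell, assuming the data
# files above have loaded correctly. The query string is an arbitrary example.
#
#   text_feat = encode_text("two dogs playing in the snow")
#   scores = similarity(text_feat, photo_features)
#   best_idx = int(np.argmax(scores))
#   print(photo_ids[best_idx], max(scores))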

def find_best_matches(image, mode, text):
  # Compute the similarity between the query (text or image) and each photo using cosine similarity
  print("Search mode:", mode)

  if mode == T2I:
    # Encode the text input
    text_features = encode_text(text)
    similarities = similarity(text_features, photo_features)
  else:
    # Encode the image input
    image_features = encode_image(image)
    similarities = similarity(image_features, photo_features)

  # Sort the photos by their similarity score
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)
  
  matched_images = []
  for i in range(3):
    # Retrieve the photo ID
    idx = best_photos[i][1]
    photo_id = photo_ids[idx]
    
    # Get all metadata for this photo
    photo_data = photos[photos["photo_id"] == photo_id].iloc[0]
    
    # Download the image and add it to the results gallery
    response = requests.get(photo_data["photo_image_url"] + "?w=640")
    img = PILIMAGE.open(BytesIO(response.content))
    matched_images.append(img)
  return matched_images
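
# Example (illustrative only, not executed at startup): calling the search
# function directly should return a list of three PIL images. The image
# argument is unused in Text2Image mode, and the query text is arbitrary.
#
#   imgs = find_best_matches(None, T2I, "a sailboat at sunset")
#   imgs[0].save("top_match.jpg")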


iface = gr.Interface(
    fn=find_best_matches,
    inputs=[
        gr.Image(label="Image to search"),
        gr.Radio([T2I, I2I], label="Search Mode"),
        gr.Textbox(lines=1, label="Text Query", placeholder="Introduce the search text...")
    ],
    theme="grass",
    outputs=[
        # Newer Gradio releases removed the .style()/.scale() helpers, so the
        # gallery layout is configured directly on the component instead.
        gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=2,
            height="auto"
        )
    ],
    title="CLIP Image Search",
    description="This application displays the top three images from the Unsplash dataset that best match the search query. The query can be provided in two modes: as text or as an image."
)

iface.queue()
iface.launch()
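
# Note: launch() serves the app locally; Gradio also supports launch(share=True)
# to create a temporary public link, which can be handy when testing the demo
# from another machine (optional, not required here).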
