sasha (HF Staff) committed
Commit f3791c9 · 0 Parent(s)

Duplicate from SDbiaseval/find-my-butterfly

Files changed (11):
  1. .gitattributes +34 -0
  2. .gitignore +1 -0
  3. README.md +14 -0
  4. app.py +71 -0
  5. elton.jpg +0 -0
  6. gaga.jpg +0 -0
  7. index_768_cosine.pickle +3 -0
  8. ken.jpg +0 -0
  9. requirements.txt +7 -0
  10. similarity_utils.py +175 -0
  11. taylor.jpg +0 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ gradio_cached_examples/
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Find My Butterfly 🦋
+ emoji: 🦋
+ colorFrom: yellow
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.12.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: SDbiaseval/find-my-butterfly
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,71 @@
+ import pickle
+ import gradio as gr
+ from datasets import load_dataset
+ from transformers import AutoModel, AutoFeatureExtractor
+ import wikipedia
+
+
+ # Load the precomputed nearest-neighbour index. Only runs once when the script is first run.
+ with open("index_768_cosine.pickle", "rb") as handle:
+     index = pickle.load(handle)
+
+ # Load model for computing embeddings.
+ feature_extractor = AutoFeatureExtractor.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
+ model = AutoModel.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
+
+ # Candidate images.
+ dataset = load_dataset("sasha/butterflies_10k_names_multiple")
+ ds = dataset["train"]
+
+
+ def query(image, top_k=4):
+     # Embed the query image with the pooled ViT representation.
+     inputs = feature_extractor(image, return_tensors="pt")
+     model_output = model(**inputs)
+     embedding = model_output.pooler_output.detach()
+     # Fetch the top-k nearest neighbours: (indices, cosine distances).
+     results = index.query(embedding, k=top_k)
+     inx = results[0][0].tolist()
+     distances = results[1][0].tolist()
+     images = ds.select(inx)["image"]
+     captions = ds.select(inx)["name"]
+     images_with_captions = [(i, c) for i, c in zip(images, captions)]
+     # Convert cosine distances into similarity scores for the label widget.
+     labels_with_probs = dict(zip(captions, distances))
+     labels_with_probs = {k: 1 - v for k, v in labels_with_probs.items()}
+     try:
+         description = wikipedia.summary(captions[0], sentences=1)
+         description = "### " + description
+         url = wikipedia.page(captions[0]).url
+         url = " You can learn more about your butterfly [here](" + str(url) + ")!"
+         description = description + url
+     except Exception:
+         description = "### Butterflies are insects in the order Lepidoptera, which also includes moths. Adult butterflies have large, often brightly coloured wings."
+         url = "https://en.wikipedia.org/wiki/Butterfly"
+         url = " You can learn more about butterflies [here](" + str(url) + ")!"
+         description = description + url
+     return images_with_captions, labels_with_probs, description
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Find my Butterfly 🦋")
+     gr.Markdown("## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!")
+     with gr.Row():
+         with gr.Column(min_width=900):
+             inputs = gr.Image(shape=(800, 1600))
+             btn = gr.Button("Find my butterfly!")
+             description = gr.Markdown()
+
+         with gr.Column():
+             outputs = gr.Gallery().style(grid=[2], height="auto")
+             labels = gr.Label()
+
+     gr.Markdown("### Image Examples")
+     gr.Examples(
+         examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
+         inputs=inputs,
+         outputs=[outputs, labels, description],  # query returns three values
+         fn=query,
+         cache_examples=True,
+     )
+     btn.click(query, inputs, [outputs, labels, description])
+
+ demo.launch()
+
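Note: app.py only loads `index_768_cosine.pickle`; the step that builds it is not part of this commit. Below is a minimal sketch of how such an index could be created with pynndescent (which requirements.txt pins), reusing the same model and dataset as app.py. The filename and cosine metric follow the pickle's name; the rest (unbatched embedding loop, default build options) is an assumption, not the author's actual build script.

```python
import pickle

import torch
from datasets import load_dataset
from pynndescent import NNDescent
from transformers import AutoFeatureExtractor, AutoModel

model_ckpt = "sasha/autotrain-butterfly-similarity-2490576840"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

ds = load_dataset("sasha/butterflies_10k_names_multiple")["train"]

# Embed every candidate image with the same pooled ViT representation
# that app.py computes at query time.
embeddings = []
with torch.no_grad():
    for image in ds["image"]:
        inputs = feature_extractor(image, return_tensors="pt")
        embeddings.append(model(**inputs).pooler_output.squeeze(0))
embeddings = torch.stack(embeddings).numpy()  # shape: (num_images, 768)

# Build the approximate-nearest-neighbour index with cosine distance,
# matching the "768_cosine" in the pickled index's name.
index = NNDescent(embeddings, metric="cosine")
index.prepare()  # finalize search structures before pickling

with open("index_768_cosine.pickle", "wb") as handle:
    pickle.dump(index, handle)
```

At query time, `index.query(embedding, k=top_k)` returns a pair of arrays (neighbour indices, cosine distances), which is why app.py turns distances into similarity scores with `1 - v`.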
elton.jpg ADDED
gaga.jpg ADDED
index_768_cosine.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:864fe29de71f0e5b56ca87b04d559ea707d4cd3798429f80f04b8a58a07e3721
+ size 53168791
ken.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers==4.25.1
+ datasets==2.7.1
+ numpy==1.21.6
+ torch==1.12.1
+ torchvision
+ pynndescent
+ wikipedia
similarity_utils.py ADDED
@@ -0,0 +1,175 @@
+ from typing import List, Union
+
+ import datasets
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from transformers import AutoFeatureExtractor, AutoModel
+
+ seed = 42
+ hash_size = 8
+ hidden_dim = 768  # ViT-base
+ np.random.seed(seed)
+
+
+ # Device.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the feature extractor for computing embeddings.
+ model_ckpt = "nateraw/vit-base-beans"
+ extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
+
+ # Data transformation chain.
+ transformation_chain = T.Compose(
+     [
+         # We first resize the input image (scaling the model's input size
+         # by 256/224) and then take a center crop.
+         T.Resize(int((256 / 224) * extractor.size["height"])),
+         T.CenterCrop(extractor.size["height"]),
+         T.ToTensor(),
+         T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
+     ]
+ )
+
+
+ # Define random vectors to project with.
+ random_vectors = np.random.randn(hash_size, hidden_dim).T
+
+
+ def hash_func(embedding, random_vectors=random_vectors):
+     """Randomly projects the embeddings and then computes bit-wise hashes."""
+     if not isinstance(embedding, np.ndarray):
+         embedding = np.array(embedding)
+     if len(embedding.shape) < 2:
+         embedding = np.expand_dims(embedding, 0)
+
+     # Random projection.
+     bools = np.dot(embedding, random_vectors) > 0
+     return [bool2int(bool_vec) for bool_vec in bools]
+
+
+ def bool2int(x):
+     # Pack a boolean vector into an integer, one bit per entry.
+     y = 0
+     for i, j in enumerate(x):
+         if j:
+             y += 1 << i
+     return y
+
+
+ def compute_hash(model: torch.nn.Module):
+     """Returns a batch-processing function that computes hashes for a dataset."""
+     device = model.device
+
+     def pp(example_batch):
+         # Prepare the input images for the model.
+         image_batch = example_batch["image"]
+         image_batch_transformed = torch.stack(
+             [transformation_chain(image) for image in image_batch]
+         )
+         new_batch = {"pixel_values": image_batch_transformed.to(device)}
+
+         # Compute embeddings and pool them, i.e., take the representations
+         # from the [CLS] token.
+         with torch.no_grad():
+             embeddings = model(**new_batch).last_hidden_state[:, 0].cpu().numpy()
+
+         # Compute hashes for the batch of images.
+         hashes = [hash_func(embeddings[i]) for i in range(len(embeddings))]
+         example_batch["hashes"] = hashes
+         return example_batch
+
+     return pp
+
+
+ class Table:
+     def __init__(self, hash_size: int):
+         self.table = {}
+         self.hash_size = hash_size
+
+     def add(self, id: int, hashes: List[int], label: int):
+         # Create a unique identifier.
+         entry = {"id_label": str(id) + "_" + str(label)}
+
+         # Add the hash values to the current table.
+         for h in hashes:
+             if h in self.table:
+                 self.table[h].append(entry)
+             else:
+                 self.table[h] = [entry]
+
+     def query(self, hashes: List[int]):
+         results = []
+
+         # Loop over the query hashes and determine if they exist in
+         # the current table.
+         for h in hashes:
+             if h in self.table:
+                 results.extend(self.table[h])
+         return results
+
+
+ class LSH:
+     def __init__(self, hash_size, num_tables):
+         self.num_tables = num_tables
+         self.tables = []
+         for i in range(self.num_tables):
+             self.tables.append(Table(hash_size))
+
+     def add(self, id: int, hash: List[int], label: int):
+         for table in self.tables:
+             table.add(id, hash, label)
+
+     def query(self, hashes: List[int]):
+         results = []
+         for table in self.tables:
+             results.extend(table.query(hashes))
+         return results
+
+
+ class BuildLSHTable:
+     def __init__(
+         self,
+         model: Union[torch.nn.Module, None],
+         batch_size: int = 48,
+         hash_size: int = hash_size,
+         dim: int = hidden_dim,
+         num_tables: int = 10,
+     ):
+         self.hash_size = hash_size
+         self.dim = dim
+         self.num_tables = num_tables
+         self.lsh = LSH(self.hash_size, self.num_tables)
+
+         self.batch_size = batch_size
+         self.hash_fn = compute_hash(model.to(device))
+
+     def build(self, ds: datasets.DatasetDict):
+         dataset_hashed = ds.map(self.hash_fn, batched=True, batch_size=self.batch_size)
+
+         for id in tqdm(range(len(dataset_hashed))):
+             hash, label = dataset_hashed[id]["hashes"], dataset_hashed[id]["labels"]
+             self.lsh.add(id, hash, label)
+
+     def query(self, image, verbose=True):
+         if isinstance(image, str):
+             image = Image.open(image).convert("RGB")
+
+         # Compute the hashes of the query image and fetch the results.
+         example_batch = dict(image=[image])
+         hashes = self.hash_fn(example_batch)["hashes"][0]
+
+         results = self.lsh.query(hashes)
+         if verbose:
+             print("Matches:", len(results))
+
+         # Calculate Jaccard index to quantify the similarity.
+         counts = {}
+         for r in results:
+             if r["id_label"] in counts:
+                 counts[r["id_label"]] += 1
+             else:
+                 counts[r["id_label"]] = 1
+         for k in counts:
+             counts[k] = float(counts[k]) / self.dim
+         return counts
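For context, similarity_utils.py ships this LSH pipeline but nothing else in the commit calls it. A minimal usage sketch follows; the `beans` dataset is an assumption, chosen because the module's `nateraw/vit-base-beans` checkpoint was trained on it and it provides the `image` and `labels` columns that `build()` expects. The sketch assumes it runs inside this module (or after importing its names).

```python
from datasets import load_dataset
from transformers import AutoModel

# The module only loads the extractor; load the matching model here.
model = AutoModel.from_pretrained(model_ckpt)

train_ds = load_dataset("beans", split="train")  # assumed dataset

# Hash every training image into the locality-sensitive hash tables.
lsh_builder = BuildLSHTable(model)
lsh_builder.build(train_ds)

# Query with a PIL image (or a file path); the result maps "id_label"
# identifiers to normalized bucket-match scores.
scores = lsh_builder.query(train_ds[0]["image"])
top_matches = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:5]
print(top_matches)
```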
taylor.jpg ADDED