|
import gradio as gr |
|
import torch |
|
import faiss |
|
|
import pandas as pd |
|
import folium |
|
from PIL import Image |
|
from pathlib import Path |
|
import torchvision.transforms as tfm |
|
from torchvision.transforms import functional as F |
|
import logging |
|
import sys |
|
import io |
|
import base64 |
|
|
import random |
|
import ast |
|
|
|
from models.apl_model_dinov2 import DINOv2FeatureExtractor |
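
# Make the vendored image-matching-models repo (and its third-party deps)

# importable before `from matching import get_matcher` below.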
|
|
|
|
|
sys.path.append(str(Path("image-matching-models"))) |
|
sys.path.append(str(Path("image-matching-models/matching/third_party"))) |
|
|
|
import util_matching |
|
from matching import get_matcher |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
|
|
MODEL_CHECKPOINT_PATH = Path("weights/best_model_95.6.torch") |
|
FAISS_INDEX_PATH = Path("faiss_index/faiss_index_2021.bin") |
|
CSV_MAPPING_PATH = Path("faiss_index/faiss_index_to_local_path.csv") |
|
|
|
DEVICE = "cpu" |
|
MATCHING_IMG_SIZE = 512 |
|
logging.info(f"Using device: {DEVICE}") |
|
|
|
|
|
|
|
for path, desc in [ |
|
(MODEL_CHECKPOINT_PATH, "Model checkpoint"), |
|
(FAISS_INDEX_PATH, "FAISS index"), |
|
(CSV_MAPPING_PATH, "Path mapping CSV"), |
|
]: |
|
if not path.exists(): |
|
raise FileNotFoundError(f"{desc} not found at: {path}") |
|
MODEL_NAME = "xfeat_steerers" |
|
|
|
matcher = get_matcher(MODEL_NAME, device=DEVICE, max_num_keypoints=2048) |
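
# xfeat_steerers keeps its device on a custom `dev` attribute, so mirror

# the global DEVICE onto it.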
|
|
|
|
|
if MODEL_NAME == 'xfeat_steerers': |
|
|
|
matcher.model.dev = DEVICE |
|
|
|
|
|
|
|
|
|
def pil_to_base64(image): |
|
"""Convert a PIL image to a base64-encoded string for HTML embedding.""" |
|
buffered = io.BytesIO() |
|
image.save(buffered, format="PNG") |
|
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") |
|
return f"data:image/png;base64,{img_str}" |
|
|
|
|
|
|
|
def create_map( |
|
final_footprint, candidate_num, filename, inliers, query_footprint=None |
|
): |
|
""" |
|
Create and return a Folium map (as HTML string) showing the final footprint (blue) |
|
and optionally the query's ground truth footprint (orange). |
|
""" |
|
|
|
if final_footprint: |
|
center = final_footprint[0] |
|
zoom = 10 |
|
elif query_footprint: |
|
center = query_footprint[0] |
|
zoom = 10 |
|
else: |
|
center = [0, 0] |
|
zoom = 2 |
|
|
|
m = folium.Map(location=center, zoom_start=zoom) |
|
|
|
|
|
if query_footprint: |
|
folium.Polygon( |
|
locations=query_footprint, |
|
popup="Ground Truth Query Footprint", |
|
color="orange", |
|
fill=True, |
|
fill_color="orange", |
|
fill_opacity=0.4, |
|
).add_to(m) |
|
|
|
|
|
if final_footprint: |
|
|
|
footprint_text = "\n".join( |
|
[f"({lat:.4f}, {lon:.4f})" for lat, lon in final_footprint] |
|
) |
|
popup_text = f"Predicted Footprint:<br>{footprint_text}<br><br>Candidate: {candidate_num}<br>Inliers: {inliers}" |
|
|
|
|
|
folium.Polygon( |
|
locations=final_footprint, |
|
popup=popup_text, |
|
color="blue", |
|
fill=True, |
|
fill_color="cyan", |
|
fill_opacity=0.5, |
|
).add_to(m) |
|
|
|
|
|
folium.Marker( |
|
location=[final_footprint[0][0], final_footprint[0][1]], |
|
popup=f"Footprint Coordinates:<br>{footprint_text}", |
|
icon=folium.Icon(color="blue"), |
|
).add_to(m) |
|
elif not query_footprint: |
|
|
|
folium.Marker( |
|
location=[0, 0], |
|
popup="No valid location found.", |
|
icon=folium.Icon(color="red"), |
|
).add_to(m) |
|
|
|
return m._repr_html_() |
|
|
|
|
|
|
|
def get_surrounding_tiles(candidate_path): |
|
""" |
|
Given a candidate image path, extract its tile indices (zoom, row, column). |
|
Then, retrieve image files in the same directory with the same zoom level |
|
and specific row/col offsets. |
|
""" |
|
candidate_zoom, candidate_row, candidate_col = util_matching.get_tile_indices( |
|
candidate_path |
|
) |
|
folder = candidate_path.parent.parent |
|
files = [ |
|
p |
|
for p in folder.glob("**/*") |
|
if p.suffix.lower() in [".jpg", ".jpeg", ".png"] |
|
] |
|
|
|
tiles = [] |
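
    # A 3x3 neighborhood around the candidate: in this tiling scheme the

    # grid positions of interest are 4 rows/cols apart at the same zoom.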
|
row_offsets = [-4, 0, 4] |
|
col_offsets = [-4, 0, 4] |
|
desired_rows = {candidate_row + r for r in row_offsets} |
|
desired_cols = {candidate_col + c for c in col_offsets} |
|
|
|
for p in files: |
|
try: |
|
zoom, row, col = util_matching.get_tile_indices(p) |
|
if ( |
|
zoom == candidate_zoom |
|
and row in desired_rows |
|
and col in desired_cols |
|
): |
|
tiles.append((p, (row, col), None)) |
|
except Exception: |
|
continue |
|
|
|
tiles.sort(key=lambda t: (t[1][0], t[1][1])) |
|
return tiles |
|
|
|
|
|
def compose_tiles_ordered(tiles, tile_size, candidate_indices): |
|
""" |
|
Given a list of tiles, create a 3x3 grid image where the positions |
|
correspond to a step of 4 from the candidate's row/col. |
|
For any missing tile, insert a blank image. |
|
""" |
|
candidate_row, candidate_col = candidate_indices |
|
grid_img = Image.new("RGB", (tile_size[0] * 3, tile_size[1] * 3)) |
|
blank = Image.new("RGB", tile_size, color=(0, 0, 0)) |
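
    # Look up tiles by (row, col); positions with no tile fall back to `blank`.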
|
|
|
tile_dict = {rc: p for (p, rc, _) in tiles} |
|
|
|
for i, row_offset in enumerate([-4, 0, 4]): |
|
for j, col_offset in enumerate([-4, 0, 4]): |
|
desired_row = candidate_row + row_offset |
|
desired_col = candidate_col + col_offset |
|
tile_path = tile_dict.get((desired_row, desired_col)) |
|
|
|
if tile_path: |
|
try: |
|
img = Image.open(tile_path).convert("RGB") |
|
except Exception as e: |
|
logging.error(f"Could not open tile {tile_path}: {e}") |
|
img = blank |
|
else: |
|
img = blank |
|
|
|
img = tfm.Resize(tile_size, antialias=True)(img) |
|
grid_img.paste(img, (j * tile_size[0], i * tile_size[1])) |
|
return grid_img |
|
|
|
|
|
|
|
def run_matching(query_image, candidate_image, base_footprint): |
|
""" |
|
    Run 4 refinement iterations of matching and return the final inlier

    count and predicted footprint, or (-1, []) when matching fails.
|
""" |
|
|
|
local_fm = None |
|
viz_params = None |
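
    # Each pass feeds the previous estimate (local_fm) back into

    # estimate_footprint, progressively refining the predicted footprint.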
|
|
|
|
|
for iteration in range(4): |
|
( |
|
num_inliers, |
|
local_fm, |
|
predicted_footprint, |
|
pretty_footprint, |
|
) = util_matching.estimate_footprint( |
|
local_fm, |
|
query_image, |
|
candidate_image, |
|
matcher, |
|
base_footprint, |
|
HW=MATCHING_IMG_SIZE, |
|
viz_params=viz_params, |
|
) |
|
|
|
|
|
|
|
if num_inliers == -1 or num_inliers is None: |
|
return -1, [] |
|
|
|
|
|
if hasattr(predicted_footprint, "tolist"): |
|
best_footprint = predicted_footprint.tolist() |
|
else: |
|
best_footprint = predicted_footprint |
|
|
|
return num_inliers, best_footprint |
|
|
|
|
|
|
|
logging.info("Loading assets. This may take a moment...") |
|
|
|
|
|
try: |
|
model = DINOv2FeatureExtractor( |
|
model_type="vit_base_patch14_reg4_dinov2.lvd142m", |
|
num_of_layers_to_unfreeze=0, |
|
desc_dim=768, |
|
aggregator_type="SALAD", |
|
) |
|
logging.info(f"Loading model checkpoint from {MODEL_CHECKPOINT_PATH}...") |
|
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE) |
|
model.load_state_dict(model_state_dict) |
|
model = model.to(DEVICE) |
|
model.eval() |
|
logging.info("DINOv2 model and checkpoint loaded successfully.") |
|
except Exception as e: |
|
logging.error(f"Failed to load the model: {e}") |
|
raise |
|
|
|
|
|
faiss_index = faiss.read_index(str(FAISS_INDEX_PATH)) |
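
# The index stores 4 descriptors per database image (0/90/180/270 degree

# rotations) in contiguous blocks: vector i maps to image i % num_db_images

# and rotation index i // num_db_images.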
|
num_db_images = faiss_index.ntotal // 4 |
|
logging.info( |
|
f"FAISS index loaded. Contains {faiss_index.ntotal} vectors for {num_db_images} unique images." |
|
) |
|
|
|
|
|
try: |
|
mapping_df = pd.read_csv(CSV_MAPPING_PATH, index_col="faiss_index") |
|
logging.info(f"Path mapping loaded. Contains {len(mapping_df)} entries.") |
|
except Exception as e: |
|
logging.error(f"Failed to load path mapping CSV: {e}") |
|
raise |
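
# Preprocessing for the global descriptor: square resize to the model's

# input size plus standard ImageNet normalization.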
|
|
|
|
|
image_transform = tfm.Compose( |
|
[ |
|
tfm.Resize((model.image_size, model.image_size), antialias=True), |
|
tfm.ToTensor(), |
|
tfm.Normalize( |
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
), |
|
] |
|
) |
|
logging.info("Assets loaded. Gradio app is ready.") |
|
|
|
|
|
|
|
def search_and_retrieve( |
|
    query_image: Image.Image, query_footprint_str: str, num_results: int, progress=gr.Progress()
|
): |
|
""" |
|
Main function to search the database, run matching, and return results. |
|
This function is a generator to update the UI sequentially. |
|
""" |
|
|
query_footprint = None |
|
if query_footprint_str: |
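
        # The footprint arrives as a Python-literal string of (lat, lon)

        # pairs, e.g. "[(10.0, 20.0), ...]"; parse it safely.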
|
try: |
|
query_footprint = ast.literal_eval(query_footprint_str) |
|
query_footprint = [list(coord) for coord in query_footprint] |
|
except (ValueError, SyntaxError): |
|
logging.warning("Could not parse query footprint string.") |
|
query_footprint = None |
|
|
|
if query_image is None: |
|
|
|
yield create_map(None, None, None, None), None |
|
return |
|
|
|
progress(0.1, desc="Preprocessing query") |
|
if query_image.mode == "RGBA": |
|
query_image = query_image.convert("RGB") |
|
|
|
image_tensor = image_transform(query_image).to(DEVICE) |
|
with torch.no_grad(): |
|
descriptor = model(image_tensor.unsqueeze(0)) |
|
descriptor_np = descriptor.cpu().numpy() |
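
    # Retrieve the k nearest vectors. Rotated duplicates of the same image

    # are skipped below, so fewer than num_results unique candidates may

    # actually be processed.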
|
|
|
k_neighbors = num_results |
|
progress(0.2, desc=f"Searching database for {k_neighbors} neighbors") |
|
distances, indices = faiss_index.search(descriptor_np, k_neighbors) |
|
flat_indices = indices.flatten() |
|
|
|
global_best_inliers = -1 |
|
global_best_footprint = None |
|
global_candidate_num = None |
|
global_filename = None |
|
global_best_display_image = None |
|
|
|
candidate_infos = [] |
|
processed_image_indices = set() |
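
    # Tensor for the local matcher: resized only (no normalization), unlike

    # the descriptor input above.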
|
|
|
    query_resized = tfm.Resize((MATCHING_IMG_SIZE, MATCHING_IMG_SIZE), antialias=True)(query_image)

    query_tensor = tfm.ToTensor()(query_resized)
|
|
|
|
|
progress(0.4, desc="Processing candidates") |
|
for faiss_idx in flat_indices: |
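
        # Decode the flat FAISS id into an image index and the rotation

        # (0/90/180/270 degrees) that scored best at retrieval time.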
|
|
|
|
|
|
|
image_index = faiss_idx % num_db_images |
|
best_rotation_index = faiss_idx // num_db_images |
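
        # The same image can appear once per rotation in the results;

        # process each unique image only once (at its best-ranked rotation).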
|
|
|
if image_index in processed_image_indices: |
|
continue |
|
|
|
processed_image_indices.add(image_index) |
|
candidate_num = len(candidate_infos) + 1 |
|
|
|
candidate_path = Path(mapping_df.loc[image_index]["local_path"]) |
|
|
|
try: |
|
_, candidate_row, candidate_col = util_matching.get_tile_indices(candidate_path) |
|
except Exception: |
|
logging.warning( |
|
f"Skipping candidate {candidate_path.name} due to parsing error." |
|
) |
|
continue |
|
|
|
tiles = get_surrounding_tiles(candidate_path) |
|
composite_img = compose_tiles_ordered( |
|
tiles, (1024, 1024), (candidate_row, candidate_col) |
|
) |
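
        # Rotate the composite into the orientation that matched best at

        # retrieval, for both the displayed image and the matching input.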
|
|
|
display_img = F.rotate( |
|
composite_img, [0, 90, 180, 270][best_rotation_index] |
|
) |
|
|
|
candidate_img_tensor = tfm.ToTensor()(composite_img) |
|
candidate_img_tensor = tfm.Resize( |
|
(MATCHING_IMG_SIZE * 3, MATCHING_IMG_SIZE * 3), antialias=True |
|
)(candidate_img_tensor) |
|
candidate_img_tensor = F.rotate( |
|
candidate_img_tensor, [0, 90, 180, 270][best_rotation_index] |
|
) |
|
|
|
candidate_img_tensor = candidate_img_tensor.to(DEVICE) |
|
|
|
progress( |
|
0.5 + len(candidate_infos) / num_results * 0.4, |
|
desc=f"Running matching for candidate {candidate_num}", |
|
) |
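
        # The candidate tile's own footprint seeds the refinement performed

        # by run_matching.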
|
base_footprint = util_matching.path_to_footprint(candidate_path) |
|
best_inliers, best_footprint = run_matching( |
|
query_tensor, candidate_img_tensor, base_footprint |
|
) |
|
|
|
if best_inliers > -1: |
|
candidate_infos.append( |
|
{ |
|
"candidate_num": candidate_num, |
|
"filename": candidate_path.name, |
|
"inliers": best_inliers, |
|
"display_image": display_img, |
|
"footprint": best_footprint, |
|
} |
|
) |
|
|
|
if best_inliers > global_best_inliers: |
|
global_best_inliers = best_inliers |
|
global_best_footprint = best_footprint |
|
global_candidate_num = candidate_num |
|
global_filename = candidate_path.name |
|
global_best_display_image = display_img |
|
|
|
progress(0.9, desc="Finalizing results") |
|
|
|
folium_map_html = create_map( |
|
global_best_footprint, |
|
global_candidate_num, |
|
global_filename, |
|
global_best_inliers, |
|
query_footprint=query_footprint, |
|
) |
|
progress(1, desc="Done") |
|
|
|
|
|
|
|
yield folium_map_html, None |
|
|
|
|
|
|
|
yield folium_map_html, global_best_display_image |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
example_list = [] |
|
queries_folder = Path("./data/queries") |
|
if queries_folder.exists() and queries_folder.is_dir(): |
|
image_extensions = ["*.jpg", "*.jpeg", "*.png"] |
|
image_files = [] |
|
for ext in image_extensions: |
|
image_files.extend(queries_folder.glob(ext)) |
|
|
|
if image_files: |
|
num_examples = min(10, len(image_files)) |
|
random_examples = random.sample(image_files, num_examples) |
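
            # Pair each example image with its ground-truth footprint so the

            # map can overlay it against the prediction.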
|
|
|
|
|
example_list = [ |
|
[str(p), str(util_matching.get_footprint_from_path(p))] |
|
for p in random_examples |
|
] |
|
logging.info( |
|
f"Loaded {len(example_list)} examples for Gradio with footprints." |
|
) |
|
else: |
|
logging.warning( |
|
f"No images found in the examples folder: {queries_folder}" |
|
) |
|
else: |
|
logging.warning(f"Examples folder not found: {queries_folder}") |
|
|
|
model_description = """ |
|
## Model Details |
|
    This is a public API for inference of the EarthLoc2 model, which implements the amazing works of:

    - EarthLoc (https://earthloc-and-earthmatch.github.io/)

    - EarthMatch (https://earthloc-and-earthmatch.github.io/)

    - AstroLoc (https://astro-loc.github.io/)



    ### Architecture

    - DINOv2-base backbone with a SALAD aggregator (output dim = 3084)

    - FAISS index of ~8 GB, indexing 161,496 × 4 images (4 rotated versions of each) from 2021



    ### Training

    - Trained on the original EarthLoc dataset (zooms 9, 10, 11) covering latitudes from -60° to 60°; polar regions are not supported

    - Training included additional queries that were not part of the test/val sets

    - 5000 iterations with a batch size of 96



    ### Performance

    - Achieves R@10 = 90.6 on the original EarthLoc test and val sets (when retrieving against the whole database as-is)

    - Overall performance is around 10% worse than AstroLoc (https://9d214e4bc329a5c3f9.gradio.live/)

    - Notably weaker on very small or very large areas



    ### Matching

    - Uses the xfeat_steerers matcher with at most 2048 keypoints; we recommend the Master matcher with 2048 keypoints if you have access to a GPU (we are too poor for it).
|
|
|
""" |
|
|
|
|
|
theme = gr.themes.Soft( |
|
primary_hue=gr.themes.colors.blue, |
|
font=gr.themes.GoogleFont("Inter"), |
|
).set( |
|
button_primary_background_fill="*primary_900", |
|
button_primary_background_fill_hover="*primary_700", |
|
) |
|
|
|
with gr.Blocks(theme=theme) as demo: |
|
gr.Markdown("# Aerial Photography Locator ") |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=2): |
|
image_input = gr.Image( |
|
type="pil", |
|
label="Aerial photos of Earth", |
|
height=400, |
|
) |
|
|
|
hidden_footprint_text = gr.Textbox( |
|
visible=False, label="Query Footprint" |
|
) |
|
|
|
slider = gr.Slider( |
|
minimum=1, |
|
maximum=20, |
|
step=1, |
|
value=10, |
|
label="Number of Candidates to Process", |
|
info="The higher this number to more likely the model is to find a match, however it takes longer to find it. Expect around 5 second more compute per candidate." |
|
) |
|
|
|
submit_btn = gr.Button("Localize Image", variant="primary") |
|
|
|
gr.Examples( |
|
examples=example_list, |
|
inputs=[image_input, hidden_footprint_text], |
|
label="Example Queries", |
|
examples_per_page=10, |
|
cache_examples=False, |
|
) |
|
|
|
|
|
with gr.Column(scale=2): |
|
map_output = gr.HTML(label="Final Footprint Map") |
|
image_output = gr.Image( |
|
type="pil", |
|
label="Best Matching Candidate", |
|
height=400, |
|
show_download_button=True, |
|
) |
|
gr.Markdown(model_description) |
|
|
|
|
|
submit_btn.click( |
|
fn=search_and_retrieve, |
|
inputs=[image_input, hidden_footprint_text, slider], |
|
outputs=[map_output, image_output], |
|
) |
|
|
|
demo.launch(share=True) |