|
import gradio as gr |
|
import torch |
|
import faiss |
|
import numpy as np |
|
import pandas as pd |
|
import folium |
|
from PIL import Image |
|
from pathlib import Path |
|
import torchvision.transforms as tfm |
|
from torchvision.transforms import functional as F |
|
import logging |
|
import sys |
|
import io |
|
import base64 |
|
import random |
|
import ast |
|
import webdataset as wds |
|
import os |
|
import pickle |
|
from functools import lru_cache |
|
import tarfile |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
|
from models.apl_model_dinov2 import DINOv2FeatureExtractor |
|
|
|
sys.path.append(str(Path("image-matching-models"))) |
|
sys.path.append(str(Path("image-matching-models/matching/third_party"))) |
|
|
|
import util_matching |
|
from matching import get_matcher |
|
|
|
# Route INFO-level (and more severe) records through the root logger.
logging.basicConfig(level=logging.INFO)


# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
|
def ensure_files_exist():
    """Check for required files and download any that are missing.

    Creates ``./faiss_index`` and ``./data/webdataset_shards`` if needed, then:

    * downloads ``faiss_index.bin`` from the ``pawlo2013/EarthLoc2_FAISS``
      dataset repo when ``faiss_index/faiss_index.bin`` is absent;
    * downloads every ``*.tar*`` file from the
      ``pawlo2013/EarthLoc_2021_Database`` dataset repo when any of the 11
      expected shards (``shard-000000.tar`` .. ``shard-000010.tar``) or their
      ``.index`` companions is absent.

    Both downloads authenticate with the module-level ``HF_TOKEN``.
    """
    faiss_dir = Path("./faiss_index")
    shard_dir = Path("data/webdataset_shards")
    faiss_dir.mkdir(exist_ok=True)
    shard_dir.mkdir(parents=True, exist_ok=True)

    # Test for the file the download actually produces. The previous check
    # looked for "faiss_index_2021.bin", which is never written, so the
    # download was attempted on every start.
    faiss_filename = "faiss_index.bin"
    if not (faiss_dir / faiss_filename).exists():
        print("Downloading FAISS index...")
        hf_hub_download(
            repo_id="pawlo2013/EarthLoc2_FAISS",
            filename=faiss_filename,
            local_dir="./faiss_index",
            token=HF_TOKEN,
            repo_type="dataset",
        )

    required_shards = [f"shard-{i:06d}.tar" for i in range(11)]
    required_indices = [f"{shard}.index" for shard in required_shards]
    missing_files = [
        name
        for name in required_shards + required_indices
        if not (shard_dir / name).exists()
    ]

    if missing_files:
        print(f"Downloading {len(missing_files)} missing shard files...")
        snapshot_download(
            repo_id="pawlo2013/EarthLoc_2021_Database",
            local_dir=shard_dir,
            allow_patterns="*.tar*",
            token=HF_TOKEN,
            repo_type="dataset",
        )
|
|
|
|
|
|
|
# Fetch any missing assets before the paths below are validated.
ensure_files_exist()


# Locations of the on-disk assets the app depends on.
MODEL_CHECKPOINT_PATH = Path("weights/best_model_95.6.torch")
FAISS_INDEX_PATH = Path("faiss_index/faiss_index.bin")
CSV_MAPPING_PATH = Path("faiss_index/faiss_index_webdataset.csv")

DEVICE = "cpu"
MATCHING_IMG_SIZE = 512
logging.info(f"Using device: {DEVICE}")

# Fail fast, with a descriptive message, if any required asset is absent.
for path, desc in (
    (MODEL_CHECKPOINT_PATH, "Model checkpoint"),
    (FAISS_INDEX_PATH, "FAISS index"),
    (CSV_MAPPING_PATH, "Path mapping CSV"),
):
    if path.exists():
        continue
    raise FileNotFoundError(f"{desc} not found at: {path}")

# Local feature matcher used by run_matching().
MODEL_NAME = "xfeat_steerers"
matcher = get_matcher(MODEL_NAME, device=DEVICE, max_num_keypoints=2048)
if MODEL_NAME == "xfeat_steerers":
    matcher.model.dev = DEVICE


# Mapping from FAISS row id to image key / shard location. The zoom, row and
# column embedded in each key ("@<z>_<r>_<c>@") are extracted once into the
# integer columns "z", "r" and "c" so neighbours can be selected with masks.
mapping_df = pd.read_csv(CSV_MAPPING_PATH, index_col="faiss_index")
parsed = (
    mapping_df["key"]
    .str.extract(r"@(?P<z>\d{1,2})_(?P<r>\d{1,5})_(?P<c>\d{1,5})@")
    .astype("int32")
)
mapping_df = mapping_df.join(parsed)


logging.info(f"Loaded mapping CSV with {len(mapping_df)} entries.")
|
|
|
|
|
# Maps a shard path to its WebDataset wrapper so each shard is wrapped once.
shard_cache = {}


def get_shard_dataset(shard_path):
    """Return the WebDataset for ``shard_path``, creating and caching it on first use."""
    try:
        return shard_cache[shard_path]
    except KeyError:
        dataset = wds.WebDataset(shard_path, handler=wds.warn_and_continue)
        shard_cache[shard_path] = dataset
        return dataset
|
|
|
@lru_cache(maxsize=100)
def load_index(index_path):
    """Unpickle and return the object stored at ``index_path``.

    Results are memoised per path (up to 100 entries), so repeated lookups
    against the same shard index do not re-read the file.
    """
    # NOTE(review): pickle executes arbitrary code on load; the index files
    # must come from a trusted source.
    with open(index_path, "rb") as handle:
        payload = pickle.load(handle)
    return payload
|
|
|
def load_image_from_shard(key):
    """Load the image stored under ``key`` and return it as an RGB PIL image.

    The shard holding the key is looked up in ``mapping_df``. When a
    ``<shard>.index`` file exists, the offset it records for the key is used to
    seek straight to the tar member; otherwise, or if the indexed read raises,
    the shard is scanned sequentially instead.

    Returns None when the key is not in the mapping, is absent from the index,
    or the member found at the recorded offset does not start with ``key``.
    """
    matches = mapping_df[mapping_df["key"] == key]
    if matches.empty:
        return None

    shard_path = matches.iloc[0]["shard_path"]
    index_path = shard_path + ".index"
    if not os.path.exists(index_path):
        return _load_via_linear_scan(shard_path, key)

    try:
        offset = load_index(index_path).get(key)
        if offset is None:
            return None

        with open(shard_path, "rb") as shard_file:
            # Position the file at the member header before handing it to tarfile.
            shard_file.seek(offset)
            with tarfile.open(fileobj=shard_file) as archive:
                entry = archive.next()
                if not entry or not entry.name.startswith(key):
                    return None
                payload = archive.extractfile(entry).read()
                return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception as e:
        logging.error(f"Error loading {key}: {str(e)}")
        return _load_via_linear_scan(shard_path, key)
|
|
|
|
|
def _load_via_linear_scan(shard_path, key):
    """Iterate the shard's samples and return the first matching image.

    A sample matches when its ``__key__`` equals ``key`` and it carries a
    non-empty ``jpg`` entry. Returns an RGB PIL image, or None if the scan
    finishes without a match.
    """
    for sample in get_shard_dataset(shard_path):
        if sample["__key__"] != key:
            continue
        img_bytes = sample.get("jpg")
        if img_bytes:
            return Image.open(io.BytesIO(img_bytes)).convert("RGB")
    return None
|
|
|
def pil_to_base64(image):
    """Serialise ``image`` as PNG and return it as a ``data:image/png;base64,...`` URI."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
|
|
|
def create_map(
    final_footprint, candidate_num, filename, inliers, query_footprint=None
):
    """Render a Folium map for a localisation result and return it as HTML.

    Args:
        final_footprint: Predicted footprint as a sequence of (lat, lon)
            pairs, drawn in blue with a marker on its first vertex. May be
            empty or None.
        candidate_num: Candidate number shown in the prediction popup.
        filename: Accepted for interface compatibility; not used here.
        inliers: Inlier count shown in the prediction popup.
        query_footprint: Optional ground-truth footprint, drawn in orange.

    Returns:
        The map's HTML representation. The map is centred on the first vertex
        of the predicted footprint, else of the query footprint, else on
        (0, 0) with a red "No valid location found." marker.
    """
    anchor = final_footprint or query_footprint
    if anchor:
        center, zoom = anchor[0], 10
    else:
        center, zoom = [0, 0], 2

    m = folium.Map(location=center, zoom_start=zoom)

    if query_footprint:
        ground_truth = folium.Polygon(
            locations=query_footprint,
            popup="Ground Truth Query Footprint",
            color="orange",
            fill=True,
            fill_color="orange",
            fill_opacity=0.4,
        )
        ground_truth.add_to(m)

    if not final_footprint:
        # Nothing predicted: flag the failure only when there is no ground
        # truth polygon to show either.
        if not query_footprint:
            failure_marker = folium.Marker(
                location=[0, 0],
                popup="No valid location found.",
                icon=folium.Icon(color="red"),
            )
            failure_marker.add_to(m)
        return m._repr_html_()

    footprint_text = "\n".join(
        f"({lat:.4f}, {lon:.4f})" for lat, lon in final_footprint
    )
    popup_text = (
        f"Predicted Footprint:<br>{footprint_text}<br><br>"
        f"Candidate: {candidate_num}<br>Inliers: {inliers}"
    )

    prediction = folium.Polygon(
        locations=final_footprint,
        popup=popup_text,
        color="blue",
        fill=True,
        fill_color="cyan",
        fill_opacity=0.5,
    )
    prediction.add_to(m)

    first_vertex = final_footprint[0]
    vertex_marker = folium.Marker(
        location=[first_vertex[0], first_vertex[1]],
        popup=f"Footprint Coordinates:<br>{footprint_text}",
        icon=folium.Icon(color="blue"),
    )
    vertex_marker.add_to(m)

    return m._repr_html_()
|
|
|
def parse_zoom_row_col_from_key(key: str):
    """Return ``(zoom, row, col)`` parsed from an image key.

    The key is split by ``util_matching.get_image_metadata_from_path``; the
    element at position 9 of that result is split on ``"_"`` and its first
    three fields are converted to integers.

    Example key:
    2021_10_30_90_@34,30714@92,81250@35,46067@92,81250@35,46067@94,21875@34,30714@94,21875@10_0404_0776@2021@34,88391@93,51562@16489@0@

    Raises:
        ValueError: if any step of the parsing fails (original error chained).
    """
    try:
        metadata = util_matching.get_image_metadata_from_path(key)
        fields = metadata[9].split("_")
        zoom, row, col = (int(fields[i]) for i in range(3))
        return zoom, row, col
    except Exception as e:
        raise ValueError(f"Failed to parse zoom,row,col from key: {key}") from e
|
|
|
|
|
def get_surrounding_tiles_sharded(candidate_key, zoom):
    """Load the candidate tile and its neighbours from the shards.

    Neighbours are the entries of ``mapping_df`` at the candidate's zoom whose
    row and column each equal the candidate's value offset by -4, 0 or +4.
    Only the first loadable image per (row, col) position is kept.

    Note: the ``zoom`` argument is replaced by the zoom parsed from
    ``candidate_key``.

    Returns:
        A list of ``(image, (row, col), key)`` tuples sorted by row then
        column; empty if the candidate key cannot be parsed.
    """
    try:
        zoom, row, col = parse_zoom_row_col_from_key(candidate_key)
    except Exception as e:
        logging.warning(f"Failed to parse candidate key {candidate_key}: {e}")
        return []

    offsets = (-4, 0, 4)
    wanted_rows = {row + delta for delta in offsets}
    wanted_cols = {col + delta for delta in offsets}

    same_zoom = mapping_df["z"] == zoom
    row_ok = mapping_df["r"].isin(wanted_rows)
    col_ok = mapping_df["c"].isin(wanted_cols)
    neighbours = mapping_df[same_zoom & row_ok & col_ok]

    tiles = []
    seen_positions = set()
    for k in neighbours["key"]:
        try:
            _, r, c = parse_zoom_row_col_from_key(k)
        except Exception:
            continue

        position = (r, c)
        if position in seen_positions:
            continue

        img = load_image_from_shard(k)
        if img is None:
            continue
        tiles.append((img, position, k))
        seen_positions.add(position)

    tiles.sort(key=lambda tile: tile[1])
    return tiles
|
|
|
def compose_tiles_ordered_sharded(tiles, tile_size, candidate_indices):
    """Assemble a 3x3 mosaic centred on the candidate tile.

    Args:
        tiles: Iterable of ``(image, (row, col), key)`` tuples.
        tile_size: Two-element size each tile is resized to before pasting.
        candidate_indices: ``(row, col)`` of the centre tile.

    Returns:
        An RGB image of three tiles by three tiles. Grid cell (i, j) holds the
        tile at the candidate's row/col offset by (-4, 0, 4)[i] / (-4, 0, 4)[j];
        positions with no tile are filled with black.
    """
    center_row, center_col = candidate_indices
    canvas = Image.new("RGB", (tile_size[0] * 3, tile_size[1] * 3))
    blank = Image.new("RGB", tile_size, color=(0, 0, 0))

    by_position = {
        (pos[0], pos[1]): img for img, pos, _ in tiles if img is not None
    }

    offsets = (-4, 0, 4)
    resize = tfm.Resize(tile_size, antialias=True)
    for i, row_offset in enumerate(offsets):
        for j, col_offset in enumerate(offsets):
            wanted = (center_row + row_offset, center_col + col_offset)
            tile = by_position.get(wanted, blank)
            if tile.mode != "RGB":
                tile = tile.convert("RGB")

            resized = resize(tile).copy()
            canvas.paste(resized, (j * tile_size[0], i * tile_size[1]))
    return canvas
|
|
|
def run_matching(query_image, candidate_image, base_footprint):
    """Estimate the query footprint against a candidate over four passes.

    Each pass calls ``util_matching.estimate_footprint`` and feeds the
    ``local_fm`` value it returns into the next pass.

    Returns:
        ``(num_inliers, footprint)`` from the final pass, with the footprint
        converted via ``.tolist()`` when available; ``(-1, [])`` as soon as a
        pass reports ``-1`` or ``None`` inliers.
    """
    local_fm = None
    viz_params = None

    for _ in range(4):
        outcome = util_matching.estimate_footprint(
            local_fm,
            query_image,
            candidate_image,
            matcher,
            base_footprint,
            HW=MATCHING_IMG_SIZE,
            viz_params=viz_params,
        )
        num_inliers, local_fm, predicted_footprint, pretty_footprint = outcome

        if num_inliers == -1 or num_inliers is None:
            return -1, []

    best_footprint = (
        predicted_footprint.tolist()
        if hasattr(predicted_footprint, "tolist")
        else predicted_footprint
    )
    return num_inliers, best_footprint
|
|
|
|
|
|
|
logging.info("Loading assets. This may take a moment...")

# Global descriptor model: build the network, restore the checkpoint, and
# switch to inference mode on DEVICE. Any failure is logged and re-raised.
try:
    model = DINOv2FeatureExtractor(
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=0,
        desc_dim=768,
        aggregator_type="SALAD",
    )
    logging.info(f"Loading model checkpoint from {MODEL_CHECKPOINT_PATH}...")
    model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(model_state_dict)
    model = model.to(DEVICE)
    model.eval()
    logging.info("DINOv2 model and checkpoint loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load the model: {e}")
    raise

# The index is treated as holding four vectors per database image, so the
# number of unique images is a quarter of the vector count.
faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
num_db_images = faiss_index.ntotal // 4
logging.info(
    f"FAISS index loaded. Contains {faiss_index.ntotal} vectors for {num_db_images} unique images."
)

# Preprocessing applied to a query before descriptor extraction.
image_transform = tfm.Compose(
    [
        tfm.Resize((model.image_size, model.image_size), antialias=True),
        tfm.ToTensor(),
        tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
logging.info("Assets loaded. Gradio app is ready.")
|
|
|
|
|
|
|
def search_and_retrieve(
    query_image: Image.Image,
    query_footprint_str: str,
    num_results: int,
    progress=gr.Progress(),
):
    """Localise a query image and yield the result map and best candidate.

    The query descriptor is searched in the FAISS index; each distinct
    retrieved database image is expanded to a 3x3 tile mosaic and matched
    against the query, and the candidate with the most inliers wins.

    Args:
        query_image: Query picture, or None when nothing was uploaded.
        query_footprint_str: Optional Python-literal string describing the
            ground-truth footprint (sequence of coordinate pairs). Unparsable
            values are ignored with a warning.
        num_results: Number of nearest neighbours to request from FAISS.
        progress: Gradio progress tracker. It is declared as a default
            argument, the form Gradio documents for progress tracking; callers
            passing three positional arguments are unaffected.

    Yields:
        ``(map_html, image)`` tuples: for a missing query a single
        ``(empty map, None)``; otherwise ``(map, None)`` followed by
        ``(map, best_candidate_display_image)``.
    """
    query_footprint = None
    if query_footprint_str:
        try:
            query_footprint = [
                list(coord) for coord in ast.literal_eval(query_footprint_str)
            ]
        except (ValueError, SyntaxError, TypeError):
            # TypeError covers literals that parse but are not a sequence of
            # sequences (e.g. a bare number).
            logging.warning("Could not parse query footprint string.")
            query_footprint = None

    if query_image is None:
        yield create_map(None, None, None, None), None
        return

    progress(0.1, desc="Preprocessing query")
    # The normalisation in image_transform uses three channel statistics, so
    # every non-RGB mode (RGBA, L, P, ...) is converted up front.
    if query_image.mode != "RGB":
        query_image = query_image.convert("RGB")

    image_tensor = image_transform(query_image).to(DEVICE)
    with torch.no_grad():
        descriptor = model(image_tensor.unsqueeze(0))
    descriptor_np = descriptor.cpu().numpy()

    # The value arrives from a UI slider; cast so FAISS always receives an int.
    num_results = int(num_results)
    progress(0.2, desc=f"Searching database for {num_results} neighbors")
    _, indices = faiss_index.search(descriptor_np, num_results)

    global_best_inliers = -1
    global_best_footprint = None
    global_candidate_num = None
    global_filename = None
    global_best_display_image = None

    matched_count = 0
    processed_image_indices = set()

    # Unrotated query at matching resolution. Each candidate rotates a fresh
    # copy of this tensor; rotating it in place made the rotations of earlier
    # (and even skipped) candidates accumulate.
    base_query_tensor = tfm.ToTensor()(
        tfm.Resize((MATCHING_IMG_SIZE, MATCHING_IMG_SIZE), antialias=True)(query_image)
    )
    rotation_angles = (0, 90, 180, 270)

    progress(0.4, desc="Processing candidates")
    for faiss_idx in indices.flatten():
        if faiss_idx < 0:
            # FAISS pads with -1 when fewer than k neighbours are available.
            continue

        # Vector i and vector i + num_db_images belong to the same database
        # image; the quotient selects which of the four rotations matched.
        image_index = int(faiss_idx % num_db_images)
        rotation_index = int(faiss_idx // num_db_images)

        if image_index in processed_image_indices:
            continue
        processed_image_indices.add(image_index)
        candidate_num = matched_count + 1

        try:
            candidate_row = mapping_df.loc[image_index]
        except KeyError as e:
            logging.warning(f"Failed to get candidate info for index {image_index}: {e}")
            continue

        candidate_key = candidate_row["key"]
        candidate_path_str = candidate_row["local_path"]
        base_footprint = util_matching.path_to_footprint(Path(candidate_path_str))

        try:
            zoom, candidate_row_idx, candidate_col_idx = parse_zoom_row_col_from_key(
                candidate_key
            )
        except ValueError as e:
            logging.warning(f"Failed to parse candidate key {candidate_key}: {e}")
            continue

        tiles = get_surrounding_tiles_sharded(candidate_key, zoom)
        composite_img = compose_tiles_ordered_sharded(
            tiles, (1024, 1024), (candidate_row_idx, candidate_col_idx)
        )

        angle = rotation_angles[rotation_index]
        # The query is turned by -angle for matching, while the mosaic shown
        # to the user is turned by +angle.
        query_tensor = F.rotate(base_query_tensor, -angle)
        display_img = F.rotate(composite_img, angle)

        candidate_img_tensor = tfm.ToTensor()(composite_img)
        candidate_img_tensor = tfm.Resize(
            (MATCHING_IMG_SIZE * 3, MATCHING_IMG_SIZE * 3), antialias=True
        )(candidate_img_tensor)
        candidate_img_tensor = candidate_img_tensor.to(DEVICE)

        progress(
            0.5 + matched_count / num_results * 0.4,
            desc=f"Running matching for candidate {candidate_num}",
        )

        best_inliers, best_footprint = run_matching(
            query_tensor, candidate_img_tensor, base_footprint
        )

        if best_inliers > -1:
            matched_count += 1
            if best_inliers > global_best_inliers:
                global_best_inliers = best_inliers
                global_best_footprint = best_footprint
                global_candidate_num = candidate_num
                global_filename = Path(candidate_path_str).name
                global_best_display_image = display_img

    progress(0.9, desc="Finalizing results")

    folium_map_html = create_map(
        global_best_footprint,
        global_candidate_num,
        global_filename,
        global_best_inliers,
        query_footprint=query_footprint,
    )
    progress(1, desc="Done")

    yield folium_map_html, None
    yield folium_map_html, global_best_display_image
|
|
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the widget nesting below is reconstructed from the call
    # order; confirm it against the deployed layout.
    example_list = []
    google_examples = []
    patterns = ["*.jpg", "*.jpeg", "*.png"]

    # ISS example queries: up to 10 random images, each paired with the
    # footprint util_matching derives from its path.
    queries_folder = Path("./data/queries")
    if not (queries_folder.exists() and queries_folder.is_dir()):
        logging.warning(f"Examples folder not found: {queries_folder}")
    else:
        image_files = [
            found for pattern in patterns for found in queries_folder.glob(pattern)
        ]
        if not image_files:
            logging.warning(
                f"No images found in the examples folder: {queries_folder}"
            )
        else:
            picked = random.sample(image_files, min(10, len(image_files)))
            example_list = [
                [str(p), str(util_matching.get_footprint_from_path(p))]
                for p in picked
            ]
            logging.info(
                f"Loaded {len(example_list)} examples for Gradio with footprints."
            )

    # Google Maps example queries: the text before the first "_" in the file
    # stem is passed as the footprint string.
    google_folder = Path("./data/google_maps_queries")
    if google_folder.exists() and google_folder.is_dir():
        google_files = [
            found for pattern in patterns for found in google_folder.glob(pattern)
        ]
        if google_files:
            google_picked = random.sample(google_files, min(10, len(google_files)))
            google_examples = [
                [str(p), str(p.stem).split("_")[0]] for p in google_picked
            ]

    model_description = """
## Model Details
This is a public API for inference of the EarthLoc2 model, which implements the amazing works of:
- EarthLoc (https://earthloc-and-earthmatch.github.io/)
- EarthMatch (https://earthloc-and-earthmatch.github.io/)
- AstroLoc (https://astro-loc.github.io/)

### Architecture
- DINOv2 base with SALAD aggregator out dim = 3072
- FAISS index ~ 8gb, indexes 161496 * 4 images (4 rotated versions) from 2021

### Training
- Trained on the original EarthLoc dataset (zooms 9,10,11), in range -60,60 latitude, polar regions not supported
- Training included additional queries which were not part of the test/val sets
- 5000 iterations with a batch size of 96

### Performance
- Achieves R@10 = 90.6 on the original EarthLoc test and val sets (when retrieving against whole db as is)
- Overall performance is around 10% worse than AstroLoc (https://9d214e4bc329a5c3f9.gradio.live/)
- Works well on satelite images between 1000 sq.km and 50000 sq.km, smaller or higher areas will not produce good results.
### Matching
- Uses the Xfeat_steerers matcher with 2048 maximal number of keypoints, we recommend Master with 2048 if you have access to GPU (we are too poor for it).
"""

    base_theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        font=gr.themes.GoogleFont("Inter"),
    )
    theme = base_theme.set(
        button_primary_background_fill="*primary_900",
        button_primary_background_fill_hover="*primary_700",
    )

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# Aerial Photography Locator ")

        with gr.Row():
            # Left column: query inputs and example galleries.
            with gr.Column(scale=2):
                image_input = gr.Image(
                    type="pil",
                    label="Aerial Photos of Earth",
                    height=400,
                )
                hidden_footprint_text = gr.Textbox(
                    visible=False, label="Query Footprint"
                )

                slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=10,
                    label="Number of Candidates to Process",
                    info=(
                        "The higher this number to more likely the model is to find a match, "
                        "however it takes longer to find it. Expect around 5 second more compute per candidate."
                    ),
                )

                submit_btn = gr.Button("Localize Image", variant="primary")

                with gr.Row():
                    with gr.Column():
                        if example_list:
                            gr.Markdown("### ISS Example Queries")
                            gr.Examples(
                                examples=example_list,
                                inputs=[image_input, hidden_footprint_text],
                                examples_per_page=5,
                                cache_examples=False,
                            )

                    with gr.Column():
                        if google_examples:
                            gr.Markdown("### Google Maps Example Queries")
                            gr.Examples(
                                examples=google_examples,
                                inputs=[image_input, hidden_footprint_text],
                                examples_per_page=5,
                                cache_examples=False,
                            )

            # Right column: results and model notes.
            with gr.Column(scale=2):
                map_output = gr.HTML(label="Final Footprint Map")
                image_output = gr.Image(
                    type="pil",
                    label="Best Matching Candidate",
                    height=400,
                    show_download_button=True,
                )
                gr.Markdown(model_description)

        submit_btn.click(
            fn=search_and_retrieve,
            inputs=[image_input, hidden_footprint_text, slider],
            outputs=[map_output, image_output],
        )

    demo.launch(share=True)