import gradio as gr
from datasets import load_dataset
from itertools import islice
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor
import torch.nn.functional as F
import os, json, time


# ---------- utils ----------
def flux_to_gray(flux_array):
    a = np.array(flux_array, dtype=np.float32)
    a = np.squeeze(a)
    if a.ndim == 3:
        axis = int(np.argmin(a.shape))
        a = np.nanmean(a, axis=axis)
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    lo = np.nanpercentile(a, 1)
    hi = np.nanpercentile(a, 99)
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        lo, hi = float(np.nanmin(a)), float(np.nanmax(a))
    norm = np.clip((a - lo) / (hi - lo + 1e-9), 0, 1)
    arr = (norm * 255).astype(np.uint8)
    return Image.fromarray(arr, mode="L")


# ---------- model ----------
model_id = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
model.eval()

# ---------- in-memory index ----------
INDEX = {
    "feats": None,   # torch.Tensor [N, 512]
    "ids": [],       # list[str]
    "thumbs": [],    # list[PIL.Image]
    "bands": []      # list[str]
}


def build_index(n=200):
    ds = load_dataset("MultimodalUniverse/jwst", split="train", streaming=True)
    feats, ids, thumbs, bands = [], [], [], []
    for rec in islice(ds, int(n)):
        pil = flux_to_gray(rec["image"]["flux"]).convert("RGB")
        t = pil.copy()
        t.thumbnail((128, 128))
        with torch.no_grad():
            inp = processor(images=pil, return_tensors="pt")
            f = model.get_image_features(**inp)   # [1, 512]
            f = F.normalize(f, p=2, dim=-1)[0]    # [512]
        feats.append(f)
        ids.append(str(rec.get("object_id")))
        bands.append(str(rec["image"].get("band")))
        thumbs.append(t)
    if not feats:
        return "No records indexed."
    INDEX["feats"] = torch.stack(feats)  # [N, 512]
    INDEX["ids"] = ids
    INDEX["thumbs"] = thumbs
    INDEX["bands"] = bands
    return f"Index built: {len(ids)} images."


def search(text_query, image_query, k=5):
    if INDEX["feats"] is None:
        return [], "Build the index first."
    with torch.no_grad():
        if text_query and str(text_query).strip():
            inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
            q = model.get_text_features(**inputs)   # [1, 512]
        elif image_query is not None:
            pil = image_query.convert("RGB")
            inputs = processor(images=pil, return_tensors="pt")
            q = model.get_image_features(**inputs)  # [1, 512]
        else:
            return [], "Enter text or upload an image."
        q = F.normalize(q, p=2, dim=-1)[0]          # [512]
    sims = (INDEX["feats"] @ q).cpu()               # [N]
    k = min(int(k), sims.shape[0])
    topk = torch.topk(sims, k=k)
    items = []
    for idx in topk.indices.tolist():
        cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
        items.append((INDEX["thumbs"][idx], cap))
    return items, f"Returned {k} results."


# ---------- evaluation helpers ----------
def _search_topk_for_eval(text_query, image_query, k=5):
    if INDEX["feats"] is None:
        return [], [], "Build the index first."
    with torch.no_grad():
        if text_query and str(text_query).strip():
            inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
            q = model.get_text_features(**inputs)
        elif image_query is not None:
            pil = image_query.convert("RGB")
            inputs = processor(images=pil, return_tensors="pt")
            q = model.get_image_features(**inputs)
        else:
            return [], [], "Enter text or upload an image."
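        # Both the index features and this query embedding are L2-normalized, so the
        # plain dot product against INDEX["feats"] below is cosine similarity
        # (the same scoring used in search() above).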
        q = F.normalize(q, p=2, dim=-1)[0]
    sims = (INDEX["feats"] @ q).cpu()
    k = min(int(k), sims.shape[0])
    topk = torch.topk(sims, k=k)
    idxs = topk.indices.tolist()
    # reuse thumbs and captions like the main search
    items = []
    for idx in idxs:
        cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
        items.append((INDEX["thumbs"][idx], cap))
    return items, idxs, f"Eval preview: top {k} ready."


def _format_eval_summary(query, k, hits, p_at_k):
    lines = []
    lines.append(f"Query: {query or '[image query]'}")
    lines.append(f"K: {k}")
    lines.append(f"Relevant marked: {hits} of {k}")
    lines.append(f"Precision@{k}: {p_at_k:.2f}")
    lines.append("Saved to eval_runs.jsonl")
    return "\n".join(lines)


def _save_eval_run(record):
    try:
        with open("eval_runs.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
    except Exception:
        pass


def _compute_avg_from_file():
    try:
        total = 0.0
        n = 0
        with open("eval_runs.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                rec = json.loads(line)
                if "precision_at_k" in rec:
                    total += float(rec["precision_at_k"])
                    n += 1
        if n == 0:
            return "No runs recorded yet."
        return f"Macro average Precision@K across {n} runs: {total/n:.2f}"
    except FileNotFoundError:
        return "No eval_runs.jsonl yet. Run at least one evaluation."


# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("JWST multimodal search. Build the index first, then search it below.")

    # Build
    n = gr.Slider(50, 1000, value=200, step=10, label="How many images to index")
    build_btn = gr.Button("Build index")
    status = gr.Textbox(label="Status", lines=2)
    build_btn.click(build_index, inputs=n, outputs=status)

    # Search
    gr.Markdown("Search the index with text or an example image")
    q_text = gr.Textbox(label="Text query", placeholder="e.g., spiral galaxy")
    q_img = gr.Image(label="Image query", type="pil")
    k = gr.Slider(1, 12, value=6, step=1, label="Top K")
    search_btn = gr.Button("Search")
    gallery = gr.Gallery(label="Results", columns=6, height=300)
    info2 = gr.Textbox(label="Search status", lines=1)
    search_btn.click(search, inputs=[q_text, q_img, k], outputs=[gallery, info2])

    # ---------- Evaluation (guided) ----------
    with gr.Accordion("Evaluation", open=False):
        gr.Markdown(
            "### What this does\n"
            "We evaluate text-to-image retrieval using Precision at K.\n"
            "Steps: pick a preset or type a query, click **Run and label**, "
            "tick the results that match the rule shown, then click **Compute metrics**. "
            "Each run is saved so you can average later."
        )

        # Preset prompts with plain-English relevance rules
        PRESETS = {
            "star with spikes": "Relevant = bright point source with clear 4 to 6 diffraction spikes. Minimal extended glow.",
            "edge-on galaxy": "Relevant = thin elongated streak. Looks like a narrow line. No round diffuse blob.",
            "spiral galaxy": "Relevant = visible spiral arms or a spiral outline. Arms can be faint.",
            "diffuse nebula": "Relevant = fuzzy cloud-like structure. No sharp round core.",
            "ring or annulus": "Relevant = ring or donut shape is the main feature.",
            "two merging objects": "Relevant = two bright blobs touching or overlapping."
        }
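
        # Note on the metric implemented below: Precision@K for one query is simply
        #   (number of top-K results you tick as relevant) / K,
        # and "Compute average across saved runs" reports the macro average of that
        # value over every row appended to eval_runs.jsonl.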

        with gr.Row():
            preset = gr.Dropdown(choices=list(PRESETS.keys()), label="Preset query (optional)")
            eval_k = gr.Slider(1, 12, value=6, step=1, label="K for evaluation")
        eval_query = gr.Textbox(label="Evaluation query (you can edit or type your own)")
        eval_img = gr.Image(label="Evaluation image (optional)", type="pil")
        rules_md = gr.Markdown()
        run_and_label = gr.Button("Run and label this query")
        eval_gallery = gr.Gallery(label="Eval top K results", columns=6, height=300)
        relevant_picker = gr.CheckboxGroup(label="Select indices of relevant results (1..K)")
        eval_md = gr.Markdown()

        # state bag for this panel
        eval_state = gr.State({"result_indices": [], "k": 5, "query": ""})

        def _on_preset_change(name):
            if name in PRESETS:
                return gr.update(value=name), PRESETS[name]
            return gr.update(), ""

        preset.change(fn=_on_preset_change, inputs=preset, outputs=[eval_query, rules_md])

        # uses helper _search_topk_for_eval defined above
        def _run_eval_query(q_txt, q_img_in, k_in, state):
            items, idxs, _ = _search_topk_for_eval(q_txt, q_img_in, k_in)
            state["result_indices"] = idxs
            state["k"] = int(k_in)
            state["query"] = q_txt if (q_txt and q_txt.strip()) else "[image query]"
            choice_labels = [str(i + 1) for i in range(len(idxs))]
            help_text = PRESETS.get((q_txt or "").strip().lower(),
                                    "Mark results that match the concept you typed.")
            return (items,
                    gr.update(choices=choice_labels, value=[]),
                    f"**Relevance rule:** {help_text}\n\nThen click **Compute metrics**.",
                    state)

        run_and_label.click(
            fn=_run_eval_query,
            inputs=[eval_query, eval_img, eval_k, eval_state],
            outputs=[eval_gallery, relevant_picker, eval_md, eval_state]
        )

        compute_btn = gr.Button("Compute metrics")

        # uses helpers _save_eval_run and _format_eval_summary defined above
        def _compute_pk(selected_indices, state):
            k_val = int(state.get("k", 5))
            query = state.get("query", "")
            hits = len(selected_indices)
            p_at_k = hits / max(k_val, 1)
            record = {
                "ts": int(time.time()),
                "query": query,
                "k": k_val,
                "relevant_indices": sorted([int(s) for s in selected_indices]),
                "precision_at_k": p_at_k
            }
            _save_eval_run(record)
            return _format_eval_summary(query, k_val, hits, p_at_k)

        compute_btn.click(fn=_compute_pk, inputs=[relevant_picker, eval_state], outputs=eval_md)

        avg_btn = gr.Button("Compute average across saved runs")
        avg_md = gr.Markdown()
        avg_btn.click(fn=_compute_avg_from_file, outputs=avg_md)

demo.launch()
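
# ---------- optional: offline summary of saved runs (sketch) ----------
# A minimal sketch, not part of the app above: eval_runs.jsonl (written by the
# Evaluation panel) holds one JSON object per run with "query", "k" and
# "precision_at_k". Dropped into its own script (any filename), it prints each
# run's Precision@K and the same macro average the "Compute average" button shows.
# It is left commented out here because demo.launch() above blocks until the UI exits.
#
#   import json
#
#   with open("eval_runs.jsonl", encoding="utf-8") as f:
#       runs = [json.loads(line) for line in f]
#   for r in runs:
#       print(f"{r['query']!r}  P@{r['k']} = {r['precision_at_k']:.2f}")
#   if runs:
#       print(f"macro average: {sum(r['precision_at_k'] for r in runs) / len(runs):.2f}")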