# Glimpse Prediction Networks (GPN)
This repo contains the best-performing pretrained GPN models (B, S, R, and RS variants), built on a SimCLR backbone and trained on DeepGaze3 glimpses over COCO scenes.
To replicate the analyses reported in the paper (https://arxiv.org/abs/2511.12715), see the GitHub repository: https://github.com/KietzmannLab/GPN
## Usage
First, set up the glimpse encoder backbone and the GPN:
```python
from huggingface_hub import snapshot_download
import torch
import sys, os
from torchvision.models import resnet50
from torchvision import transforms
REPO_ID = "novelmartis/GPN"
# download whole repo (code + weights) to a local folder
repo_path = snapshot_download(REPO_ID)
# make repo importable
sys.path.insert(0, repo_path)
from gpn_model import build_gpn
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
# setup glimpse encoder - SimCLR-RN50
# obtained from: https://huggingface.co/lightly-ai/simclrv1-imagenet1k-resnet50-1x
ckpt = os.path.join(repo_path, "SimCLR ResNet50 1x.pth")
encoder = resnet50(weights=None)
encoder.load_state_dict(torch.load(ckpt, map_location="cpu")['state_dict'])
preprocess = transforms.Compose([
transforms.Resize(224, antialias=True),
transforms.ConvertImageDtype(torch.float),
])
encoder.to(device)
encoder.eval()
# setting up avgpool activation extraction from the RN50-SimCLR
activation = {}  # remember to empty this dict after each forward pass

def get_activation(name):  # factory producing forward hooks
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

encoder.avgpool.register_forward_hook(get_activation('avgpool'))
# choose a GPN variant: "B", "S", "R", or "RS"
variant = "R"
ckpt = os.path.join(repo_path, f"gpn_{variant}.pth")
gpn = build_gpn(variant=variant)
gpn.load_state_dict(torch.load(ckpt, map_location="cpu"))
gpn.to(device)
gpn.eval()
```
Next, gather your N glimpse images for each of your B scenes into a [B,N,H,W,3] uint8 numpy array.
Note that during training, COCO scene images were resized so that their shorter side was 256 px, and a central 256×256 px crop was taken. DeepGaze3 was fed this image to obtain fixations, and 91×91 px crops around the fixations were taken as glimpses. I'd advise applying similar preprocessing steps to your images to obtain 'input_glimpses' (a sketch follows), unless you have a priori reasons to do otherwise.
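Here is a minimal sketch of these preprocessing steps. The helper `make_glimpses` is hypothetical (not part of this repo), and it assumes you already have fixation coordinates (e.g. from DeepGaze3) expressed in pixel coordinates of the 256×256 crop:

```python
import numpy as np
import torch
from torchvision.transforms import functional as TF

def make_glimpses(scene_uint8, fixations_xy, glimpse_size=91):
    # scene_uint8: [H,W,3] uint8 scene image
    # fixations_xy: [N,2] (x, y) fixation coords in the 256x256 cropped frame
    img = torch.from_numpy(scene_uint8).permute(2, 0, 1)   # -> [3,H,W]
    img = TF.resize(img, 256, antialias=True)              # shorter side -> 256 px
    img = TF.center_crop(img, 256)                         # central 256x256 crop
    half = glimpse_size // 2
    img = TF.pad(img, [half, half, half, half])            # pad so edge fixations fit
    glimpses = []
    for x, y in fixations_xy:
        # crop a 91x91 window around the (padded) fixation location
        patch = img[:, int(y):int(y) + glimpse_size, int(x):int(x) + glimpse_size]
        glimpses.append(patch.permute(1, 2, 0).numpy())    # back to [91,91,3] uint8
    return np.stack(glimpses)                              # [N,91,91,3]
```

Applying `make_glimpses` to each of your B scenes and stacking the results yields the [B,N,91,91,3] array.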
Gather the glimpse embeddings:
```python
# obtain 'glimpse_embeddings' [B,N,2048]
input_glimpses_tensor = preprocess(torch.from_numpy(input_glimpses.reshape(B*N, H, W, 3)).permute(0, 3, 1, 2)).to(device)
with torch.no_grad():
    _ = encoder(input_glimpses_tensor)
glimpse_embeddings = activation['avgpool'].squeeze(-1).squeeze(-1).cpu().numpy()  # [B*N,2048]
glimpse_embeddings = glimpse_embeddings.reshape(B, N, -1)
```
Finally, pass the glimpse embeddings into the GPN:
```python
# run the GPN on glimpse_embeddings and saccade_inputs [B,N,2]
# set return_all_actvs=True to obtain internal representations
with torch.no_grad():
    representations, predictions = gpn(
        torch.from_numpy(glimpse_embeddings).to(device),
        torch.from_numpy(saccade_inputs).float().to(device),
        return_all_actvs=True,
    )
```
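Note that `saccade_inputs` is assumed to already exist above. One plausible construction, assuming each entry is the (Δx, Δy) displacement from the current fixation to the next, normalized by the 256 px frame, is sketched below; `all_fixations_xy` is a hypothetical [B,N,2] array of fixation coordinates, and you should verify the exact convention against the training code in the GitHub repository:

```python
import numpy as np

# ASSUMPTION: saccade_inputs[b, t] = displacement from fixation t to t+1,
# normalized by the 256 px frame; verify against the GPN training code.
fixations = np.asarray(all_fixations_xy, dtype=np.float32)  # [B,N,2] pixel coords
saccade_inputs = np.zeros_like(fixations)                   # [B,N,2]
saccade_inputs[:, :-1] = (fixations[:, 1:] - fixations[:, :-1]) / 256.0
```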
GPN outputs:
- predictions: [B,N,2048], containing the predicted T+1 glimpse embedding at each timestep T
- representations: a dict of internal activations; for variant = "R", representations['lstm_out'] has the best alignment to ventral and parietal ROI RDMs in NSD
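As a quick sanity check (not part of the original pipeline), you can compare each predicted embedding against the actual next-glimpse embedding with cosine similarity:

```python
import torch
import torch.nn.functional as F

actual = torch.from_numpy(glimpse_embeddings).to(device)  # [B,N,2048]
# the prediction at step T targets the glimpse at T+1, so align [:, :-1] with [:, 1:]
cos = F.cosine_similarity(predictions[:, :-1], actual[:, 1:], dim=-1)
print(f"mean predicted-vs-actual cosine similarity: {cos.mean().item():.3f}")
```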