---
license: mit
pipeline_tag: image-feature-extraction
tags:
- scene-representation
- active-vision
- lstm
---

# Glimpse Prediction Networks (GPN)

This repo contains the best pretrained GPN models (B, S, R, and RS variants), built on a SimCLR backbone and trained on DeepGaze3 glimpses of COCO scenes. To replicate the analyses reported in the paper (https://arxiv.org/abs/2511.12715), see the GitHub repository: https://github.com/KietzmannLab/GPN

## Usage

First, set up the glimpse encoder backbone and the GPN:

```python
from huggingface_hub import snapshot_download
import torch
import sys, os
from torchvision.models import resnet50
from torchvision import transforms

REPO_ID = "novelmartis/GPN"

# download the whole repo (code + weights) to a local folder
repo_path = snapshot_download(REPO_ID)

# make the repo importable
sys.path.insert(0, repo_path)
from gpn_model import build_gpn

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# set up the glimpse encoder - SimCLR-RN50
# obtained from: https://huggingface.co/lightly-ai/simclrv1-imagenet1k-resnet50-1x
ckpt = os.path.join(repo_path, "SimCLR ResNet50 1x.pth")
encoder = resnet50(weights=None)
encoder.load_state_dict(torch.load(ckpt, map_location="cpu")['state_dict'])
preprocess = transforms.Compose([
    transforms.Resize(224, antialias=True),
    transforms.ConvertImageDtype(torch.float),
])
encoder.to(device)
encoder.eval()

# set up avgpool activation extraction from the RN50-SimCLR;
# the hook overwrites activation['avgpool'] on every forward pass
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
encoder.avgpool.register_forward_hook(get_activation('avgpool'))

# choose the GPN variant: B/S/R/RS
variant = "R"
ckpt = os.path.join(repo_path, f"gpn_{variant}.pth")
gpn = build_gpn(variant=variant)
gpn.load_state_dict(torch.load(ckpt, map_location="cpu"))
gpn.to(device)
gpn.eval()
```

Next, gather your N glimpse images for each of the B scenes into a [B,N,H,W,3] uint8 numpy array. Note the preprocessing used during training: each COCO scene image was scaled so that its shorter side was 256 px and a central 256x256 crop was taken; DeepGaze3 was fed this crop to obtain fixations, and 91x91 px crops around the fixations were taken as glimpses. I'd advise you to apply similar preprocessing to your images to obtain `input_glimpses`, unless you have a priori reasons to do otherwise (a sketch of this glimpse-extraction step is given at the end of this README).

Gather the glimpse embeddings:

```python
# obtain 'glimpse_embeddings' [B,N,2048]
input_glimpses_tensor = preprocess(torch.from_numpy(input_glimpses.reshape(B*N,H,W,3)).permute(0, 3, 1, 2)).to(device)
with torch.no_grad():
    _ = encoder(input_glimpses_tensor)
glimpse_embeddings = activation['avgpool'].detach().cpu().numpy().squeeze()  # [B*N,2048]
glimpse_embeddings = glimpse_embeddings.reshape(B,N,-1)
```

Finally, pass the glimpse embeddings, together with the corresponding `saccade_inputs` (a [B,N,2] float32 numpy array of saccade inputs; see the GitHub repository for the exact convention used during training), into the GPN:

```python
# run the GPN on glimpse_embeddings and saccade_inputs [B,N,2]
# set return_all_actvs=True to acquire internal representations
with torch.no_grad():
    representations, predictions = gpn(torch.from_numpy(glimpse_embeddings).to(device),
                                       torch.from_numpy(saccade_inputs).to(device),
                                       return_all_actvs=True)
```

GPN outputs:
- `predictions` [B,N,2048]: at each time step T, the prediction of the glimpse embedding at T+1
- `representations`: a dict of internal activations; e.g. `representations['lstm_out']`, when variant = "R", has the best alignment to ventral and parietal ROI RDMs in NSD
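
Since the prediction at step T targets the embedding of glimpse T+1, one simple sanity check is to score predictions against the actual next-glimpse embeddings. Here is a minimal sketch reusing the arrays from above; the cosine-similarity metric is our choice for illustration, not necessarily the paper's evaluation metric:

```python
import torch.nn.functional as F

# predictions made at steps 0..N-2, compared with the actual embeddings at steps 1..N-1
pred = predictions[:, :-1, :]                                    # [B,N-1,2048]
target = torch.from_numpy(glimpse_embeddings[:, 1:, :]).to(device)
cos = F.cosine_similarity(pred, target, dim=-1)                  # [B,N-1]
print(f"mean next-glimpse cosine similarity: {cos.mean().item():.3f}")

# internal representations, e.g. the LSTM output states used for the NSD ROI analyses
lstm_out = representations['lstm_out']
```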
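
### Appendix: glimpse extraction (sketch)

A minimal sketch of the glimpse-extraction step described under Usage. The fixation-coordinate convention and the border handling are our assumptions (here: (x, y) pixel coordinates in the 256x256 crop, with zero-padding for glimpses extending past the image border), and running DeepGaze3 itself is not shown; check the GitHub repository for the exact training pipeline.

```python
import numpy as np
from PIL import Image
from torchvision import transforms

GLIMPSE = 91  # glimpse side length in px, as used during training

def scene_to_256(img: Image.Image) -> np.ndarray:
    """Scale the shorter side to 256 px, then take a central 256x256 crop."""
    t = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)])
    return np.asarray(t(img.convert("RGB")))  # [256,256,3] uint8

def extract_glimpses(scene: np.ndarray, fixations: np.ndarray) -> np.ndarray:
    """Cut a 91x91 crop around each (x, y) fixation; zero-pad at borders (assumption)."""
    r = GLIMPSE // 2  # 45 px on each side of the fixation
    padded = np.zeros((scene.shape[0] + 2 * r, scene.shape[1] + 2 * r, 3), dtype=np.uint8)
    padded[r:-r, r:-r] = scene
    glimpses = np.empty((len(fixations), GLIMPSE, GLIMPSE, 3), dtype=np.uint8)
    for i, (x, y) in enumerate(np.round(fixations).astype(int)):
        glimpses[i] = padded[y : y + GLIMPSE, x : x + GLIMPSE]  # centred on (x, y)
    return glimpses
```

Stacking the per-scene outputs of `extract_glimpses` across your B scenes (with N fixations each, e.g. from DeepGaze3) yields the [B,N,H,W,3] `input_glimpses` array expected above, with H = W = 91.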