import os
from os.path import join

import numpy as np
import pandas as pd
import PIL.Image
import reverse_geocoder
import torch
from torch.utils.data import Dataset


class GeoDataset(Dataset):
    """Image dataset whose ground truth (image id, latitude, longitude) is read from a CSV."""

    def __init__(self, image_folder, annotation_file, transformation, tag="image_id"):
        self.image_folder = image_folder
        gt = pd.read_csv(annotation_file, dtype={tag: str})
        # Keep only annotations whose .jpg file actually exists on disk.
        files = {f.replace(".jpg", "") for f in os.listdir(image_folder)}
        gt = gt[gt[tag].isin(files)]
        self.processor = transformation
        self.gt = [
            (row[tag], row["latitude"], row["longitude"]) for _, row in gt.iterrows()
        ]
        self.tag = tag

    def fid(self, i):
        # Image id of sample i.
        return self.gt[i][0]

    def latlon(self, i):
        # Ground-truth (latitude, longitude) of sample i.
        return self.gt[i][1], self.gt[i][2]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        # The transformation receives the dataset itself, the index, and the image path.
        fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
        return self.processor(self, idx, fp)


def load_plonk(path):
    from hydra import initialize, compose
    from hydra.utils import instantiate
    from omegaconf import open_dict
    from models.module import Geolocalizer

    # Load the Hydra config shipped alongside the checkpoint; `config_path`
    # is resolved relative to this file's directory.
    with initialize(version_base=None, config_path="osv5m__best_model"):
        cfg = compose(config_name="config", overrides=[])

    # Recent `transformers` releases no longer register the `position_ids` buffer in
    # the CLIP vision embeddings, so drop it from the checkpoint and re-save a
    # temporary copy before loading.
    checkpoint = torch.load(join(path, "last.ckpt"))
    del checkpoint["state_dict"][
        "model.backbone.clip.vision_model.embeddings.position_ids"
    ]
    torch.save(checkpoint, join(path, "last2.ckpt"))

    with open_dict(cfg):
        cfg.checkpoint = join(path, "last2.ckpt")

    # Quadtree classification head: 11399 cells, three output values per cell.
    cfg.num_classes = 11399
    cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
    cfg.model.network.head.final_dim = cfg.num_classes * 3
    cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")

    # Blank out training-time paths; only test-time inference is needed here.
    cfg.dataset.train_dataset.path = ""
    cfg.dataset.val_dataset.path = ""
    cfg.dataset.test_dataset.path = ""
    cfg.logger.save_dir = ""
    cfg.data_dir = ""
    cfg.root_dir = ""
    cfg.mode = "test"
    cfg.model.network.backbone.instance.path = (
        "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
    )
    transform = instantiate(cfg.dataset.test_transform)
    model = Geolocalizer.load_from_checkpoint(join(path, "last2.ckpt"), cfg=cfg.model)
    os.remove(join(path, "last2.ckpt"))  # clean up the patched checkpoint

    @torch.no_grad()
    def inference(model, x):
        # x is (image_ids, image_tensor); return the backbone's CLS-token features.
        return x[0], model.model.backbone({"img": x[1].to(model.device)})[:, 0, :].cpu()

    def collate_fn(batch):
        # Batch of (image_id, tensor) pairs -> (list of ids, stacked tensor).
        return [b[0] for b in batch], torch.stack([b[1] for b in batch], dim=0)

    def operate(self, idx, fp):
        # Preprocess with the test-time transform instantiated above; `self` is the
        # GeoDataset, so `self.gt[idx][0]` is the image id.
        proc = transform(PIL.Image.open(fp))
        return self.gt[idx][0], proc

    return model, operate, inference, collate_fn


def load_clip(which):
    # We evaluate on:
    # - "openai/clip-vit-base-patch32"
    # - "openai/clip-vit-large-patch14-336"
    # - "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    # - "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
    # - "geolocal/StreetCLIP"
    from transformers import CLIPProcessor, CLIPModel

    @torch.no_grad()
    def inference(model, img):
        image_ids = img.data.pop("image_id")
        image_input = img.to(model.device)
        # Drop the extra batch dimension added by the per-image processor call.
        image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
        features = model.get_image_features(**image_input)
        features /= features.norm(dim=-1, keepdim=True)  # L2-normalize embeddings
        return image_ids, features.cpu()

    processor = CLIPProcessor.from_pretrained(which)

    def operate(self, idx, fp):
        pil = PIL.Image.open(fp)
        proc = processor(images=pil, return_tensors="pt")
        # Attach the image id so it survives the default collate.
        proc["image_id"] = self.gt[idx][0]
        return proc

    return CLIPModel.from_pretrained(which), operate, inference, None


def load_dino(which):
    # We evaluate on:
    # - 'facebook/dinov2-large'
    from transformers import AutoImageProcessor, AutoModel

    @torch.no_grad()
    def inference(model, img):
        image_ids = img.data.pop("image_id")
        image_input = img.to(model.device)
        image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
        # Use the CLS token of the last hidden state as the image feature.
        features = model(**image_input).last_hidden_state[:, 0]
        features /= features.norm(dim=-1, keepdim=True)
        return image_ids, features.cpu()

    processor = AutoImageProcessor.from_pretrained(which)

    def operate(self, idx, fp):
        pil = PIL.Image.open(fp)
        proc = processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc

    return AutoModel.from_pretrained(which), operate, inference, None


def get_backbone(name):
    # A local directory is treated as an OSV-5M (PLONK) checkpoint folder;
    # otherwise `name` is interpreted as a Hugging Face model id.
    if os.path.isdir(name):
        return load_plonk(name)
    elif "clip" in name.lower():
        return load_clip(name)
    elif "dino" in name.lower():
        return load_dino(name)
    raise ValueError(f"Unsupported backbone: {name}")
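

# Minimal usage sketch (an illustrative assumption, not part of the evaluated
# pipeline): wire a backbone into GeoDataset and extract features with a
# DataLoader. The folder, CSV path, and batch size below are hypothetical
# placeholders; the CSV is assumed to have `image_id`, `latitude`, and
# `longitude` columns, with images stored as `<image_id>.jpg`.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    model, operate, inference, collate_fn = get_backbone(
        "openai/clip-vit-base-patch32"
    )
    model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")

    dataset = GeoDataset("images/", "annotations.csv", transformation=operate)
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

    all_ids, all_feats = [], []
    for batch in loader:
        ids, feats = inference(model, batch)
        all_ids.extend(ids)
        all_feats.append(feats)
    features = torch.cat(all_feats, dim=0)  # (num_images, feature_dim), on CPU
    coords = [dataset.latlon(i) for i in range(len(dataset))]  # ground-truth (lat, lon)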