Commit 2b43c93
Parent(s): 78711ff

Add inference code
Files changed:
- assets/model_architecture.png +3 -0
- config.json +26 -0
- examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif +3 -0
- examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif +3 -0
- examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif +3 -0
- examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif +3 -0
- inference.py +528 -0
- prithvi_mae.py +766 -0
- requirements.txt +5 -0
assets/model_architecture.png
ADDED (Git LFS)
config.json
ADDED
@@ -0,0 +1,26 @@
{
    "architecture": "prithvi_eo_v2_tiny",
    "num_features": 192,
    "pretrained_cfg": {
        "img_size": 224,
        "num_frames": 4,
        "patch_size": [1, 16, 16],
        "in_chans": 6,
        "embed_dim": 192,
        "depth": 12,
        "num_heads": 3,
        "decoder_embed_dim": 512,
        "decoder_depth": 8,
        "decoder_num_heads": 16,
        "mlp_ratio": 4,
        "coords_encoding": ["time", "location"],
        "coords_scale_learn": true,
        "mask_ratio": 0.75,
        "norm_pix_loss": false,
        "bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
        "mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
        "std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
        "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny",
        "paper_ids": "arXiv:X.X"
    }
}
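Note (not part of the commit): a minimal sketch of how config.json is consumed by inference.py further down. The JSON is read with yaml.safe_load (plain JSON parses fine with the YAML loader) and the nested pretrained_cfg dict is unpacked into the PrithviMAE constructor; inference.py additionally overrides coords_encoding, num_frames and in_chans first and then loads the pretrained checkpoint, both omitted here.

import yaml
from prithvi_mae import PrithviMAE

with open("config.json") as f:
    cfg = yaml.safe_load(f)["pretrained_cfg"]  # the YAML loader also parses plain JSON

# Keys the constructor does not know (e.g. "bands", "mean", "std", "origin_url")
# are absorbed by **kwargs in PrithviMAE, so the dict can be unpacked directly.
model = PrithviMAE(**cfg)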
examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif
ADDED (Git LFS)
examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif
ADDED (Git LFS)
examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif
ADDED (Git LFS)
examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif
ADDED (Git LFS)
inference.py
ADDED
@@ -0,0 +1,528 @@
import argparse
import functools
import os
from typing import List, Union
import re
import datetime
import numpy as np
import pandas as pd
import rasterio
import torch
import yaml
from einops import rearrange

from functools import partial

from torch.distributed.checkpoint import state_dict

from prithvi_mae import PrithviMAE

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99.9


def process_channel_group(orig_img, new_img, channels, mean, std):
    """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
    original range using *data_mean* and *data_std* and then lowest and highest percentiles are
    removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.

    Args:
        orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
        new_img: torch.Tensor representing image with shape = (bands, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.

    Returns:
        torch.Tensor with shape (num_channels, height, width) for original image
        torch.Tensor with shape (num_channels, height, width) for the other image
    """

    mean = torch.tensor(np.asarray(mean)[:, None, None])  # C H W
    std = torch.tensor(np.asarray(std)[:, None, None])
    orig_img = orig_img[channels, ...]
    valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
    valid_mask[orig_img == NO_DATA_FLOAT] = False

    # Back to original data range
    orig_img = (orig_img * std[channels]) + mean[channels]
    new_img = (new_img[channels, ...] * std[channels]) + mean[channels]

    # Rescale (enhancing contrast)
    max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
    new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)

    # No data as zeros
    orig_img[~valid_mask] = 0
    new_img[~valid_mask] = 0

    return orig_img, new_img


def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return image + meta info.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.lnglat()
        except:
            # Cannot read coords
            coords = None

    return img, meta, coords


def save_geotiff(image, output_path: str, meta: dict):
    """Save multi-band image in Geotiff file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):
    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(
    file_paths: List[str],
    mean: List[float],
    std: List[float],
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images in *file_paths*.
        std: list containing std values for each band in the images in *file_paths*.

    Returns:
        np.array containing created example
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            match = re.search(r'(\d{7,8}T\d{6})', file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split('T')[0][4:]
                if len(julian_day) == 3:
                    julian_day = int(julian_day)
                else:
                    julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f'Could not extract timestamp for {file} ({e})')

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas


def run_model(
    model: torch.nn.Module,
    input_data: torch.Tensor,
    temporal_coords: None | torch.Tensor,
    location_coords: None | torch.Tensor,
    mask_ratio: float,
    device: torch.device,
):
    """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).

    Args:
        model: MAE model to run.
        input_data: torch.Tensor with shape (B, C, T, H, W).
        mask_ratio: mask ratio to use.
        device: device where model should run.

    Returns:
        2 torch.Tensor with shape (B, C, T, H, W).
    """

    with torch.no_grad():
        x = input_data.to(device)

        _, pred, mask = model(x, temporal_coords, location_coords, mask_ratio)

    # Create mask and prediction images (un-patchify)
    mask_img = (
        model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
    )
    pred_img = model.unpatchify(pred).detach().cpu()

    # Mix visible and predicted patches
    rec_img = input_data.clone()
    rec_img[mask_img == 1] = pred_img[
        mask_img == 1
    ]  # binary mask: 0 is keep, 1 is remove

    # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
    mask_img = (~(mask_img.to(torch.bool))).to(torch.float)

    return rec_img, mask_img


def save_rgb_imgs(
    input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
):
    """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(
            orig_img=input_img[:, t, :, :],
            new_img=rec_img[:, t, :, :],
            channels=channels,
            mean=mean,
            std=std,
        )

        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # Saving images

        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
            meta=meta_data[t],
        )

        save_geotiff(
            image=_convert_np_uint8(rgb_pred),
            output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
            meta=meta_data[t],
        )

        save_geotiff(
            image=_convert_np_uint8(rgb_mask),
            output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
            meta=meta_data[t],
        )


def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
    """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.

    Args:
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    mean = torch.tensor(np.asarray(mean)[:, None, None])  # C H W
    std = torch.tensor(np.asarray(std)[:, None, None])

    for t in range(rec_img.shape[1]):
        # Back to original data range
        rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)

        mask_img_t = mask_img[:, t, :, :].to(torch.int16)

        # Saving images

        save_geotiff(
            image=rec_img_t,
            output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
            meta=meta_data[t],
        )

        save_geotiff(
            image=mask_img_t,
            output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
            meta=meta_data[t],
        )


def main(
    data_files: List[str],
    config_path: str,
    checkpoint: str,
    output_dir: str,
    rgb_outputs: bool,
    mask_ratio: float = None,
    input_indices: list[int] = None,
):
    os.makedirs(output_dir, exist_ok=True)

    # Get parameters --------

    import json
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)['pretrained_cfg']

    batch_size = 1
    bands = config['bands']
    num_frames = len(data_files)
    mean = config['mean']
    std = config['std']
    coords_encoding = config['coords_encoding']
    img_size = config['img_size']
    mask_ratio = mask_ratio or config['mask_ratio']

    print(
        f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
    )
    if len(data_files) != 4:
        print(
            "The original model was trained for four time steps. \nResults with different numbers of time steps may vary"
        )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print(f"Using {device} device.\n")

    # Loading data ---------------------------------------------------------------------------------

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=data_files, indices=input_indices, mean=mean, std=std
    )

    if len(temporal_coords) != num_frames and 'time' in coords_encoding:
        coords_encoding.remove('time')
    if not len(location_coords) and 'location' in coords_encoding:
        coords_encoding.remove('location')

    # Create model and load checkpoint -------------------------------------------------------------

    config.update(
        coords_encoding=coords_encoding,
        num_frames=num_frames,
        in_chans=len(bands),
    )

    model = PrithviMAE(**config)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
    # discard fixed pos_embedding weight
    for k in list(state_dict.keys()):
        if k == 'encoder.pos_embed':
            state_dict[k] = model.encoder.pos_embed
        elif k == 'decoder.decoder_pos_embed':
            state_dict[k] = model.decoder.decoder_pos_embed
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded checkpoint from {checkpoint}")

    # Running model --------------------------------------------------------------------------------

    model.eval()
    channels = [bands.index(b) for b in ["B04", "B03", "B02"]]  # BGR -> RGB

    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = img_size - (original_h % img_size)
    pad_w = img_size - (original_w % img_size)
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window
    batch = torch.tensor(input_data, device="cpu")
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    temporal_coords = torch.Tensor(temporal_coords, device=device).unsqueeze(0)
    location_coords = torch.Tensor(location_coords[0], device=device).unsqueeze(0)

    # Run model
    rec_imgs = []
    mask_imgs = []
    for x in windows:
        rec_img, mask_img = run_model(model, x, temporal_coords, location_coords, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

    # Build images from patches
    rec_imgs = rearrange(
        rec_imgs,
        "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=len(bands),
        t=num_frames,
        h1=h1,
        w1=w1,
    )
    mask_imgs = rearrange(
        mask_imgs,
        "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=len(bands),
        t=num_frames,
        h1=h1,
        w1=w1,
    )

    # Cut padded images back to original size
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    # Build output images
    if rgb_outputs:
        for d in meta_data:
            d.update(count=3, dtype="uint8", compress="lzw", nodata=0)

        save_rgb_imgs(
            batch_full[0, ...],
            rec_imgs_full[0, ...],
            mask_imgs_full[0, ...],
            channels,
            mean,
            std,
            output_dir,
            meta_data,
        )
    else:
        for d in meta_data:
            d.update(compress="lzw", nodata=0)

        save_imgs(
            rec_imgs_full[0, ...],
            mask_imgs_full[0, ...],
            mean,
            std,
            output_dir,
            meta_data,
        )

    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_files",
        type=str,
        nargs="+",
        default=["examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
                 "examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
                 "examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
                 "examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
                 ],
        help="Path to the data files. Assumes multi-band files.",
    )
    parser.add_argument(
        "--config_path",
        "-c",
        type=str,
        default="config.json",
        help="Path to json file containing model training parameters.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="Prithvi_EO_V2_tiny.pt",
        help="Path to a checkpoint file to load from.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--mask_ratio",
        default=0.75,
        type=float,
        help="Masking ratio (percentage of removed patches). "
        "If None (default) use same value used for pretraining.",
    )
    parser.add_argument(
        "--input_indices",
        default=None,
        type=int,
        nargs="+",
        help="0-based indices of channels to be selected from the input. By default takes all.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )
    args = parser.parse_args()

    main(**vars(args))
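Usage note (not part of the commit): the script above can be driven through its CLI or by calling main() directly. Below is a minimal sketch using the argparse defaults shown above; the checkpoint Prithvi_EO_V2_tiny.pt is assumed to have been downloaded into the working directory, since it is not added by this commit.

# Equivalent to: python inference.py --rgb_outputs
from inference import main

main(
    data_files=[
        "examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
        "examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
        "examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
        "examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
    ],
    config_path="config.json",
    checkpoint="Prithvi_EO_V2_tiny.pt",  # assumed to exist locally
    output_dir="output",
    rgb_outputs=True,
    mask_ratio=0.75,
    input_indices=None,
)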
prithvi_mae.py
ADDED
@@ -0,0 +1,766 @@
# Copyright (c) IBM Corp. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# transformers: https://github.com/huggingface/transformers
# --------------------------------------------------------

import warnings
import logging
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block

logger = logging.getLogger(__name__)


def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 3D sin/cos positional embeddings.

    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
    """ Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]

    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)

    return emb


def _init_weights(module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def _interpolate_pos_encoding(
    pos_embed: torch.Tensor,
    grid_size: tuple[int, int, int] | list[int],
    patch_size: tuple[int, int, int] | list[int],
    shape: tuple[int, int, int],
    embed_dim: int,
):
    """
    Adapted from:
    - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
    - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
    """
    t, h, w = shape
    t_patches = t // patch_size[0]
    h_patches = h // patch_size[1]
    w_patches = w // patch_size[2]

    if [t_patches, h_patches, w_patches] == grid_size:
        # No interpolation needed
        return pos_embed
    if t_patches != grid_size[0]:
        # Re-compute pos embedding to handle changed num_frames
        new_grid_size = (t_patches, *grid_size[1:])
        new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True)
        new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
    else:
        new_grid_size = grid_size
        new_pos_embed = pos_embed

    class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]

    patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2)

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(h_patches, w_patches),
        mode='bicubic',
        align_corners=True,
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed"""
    def __init__(
            self,
            input_size: tuple[int, int, int] = (1, 224, 224),
            patch_size: tuple[int, int, int] = (1, 16, 16),
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: nn.Module | None = None,
            flatten: bool = True,
            bias: bool = True,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape

        if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
            warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}. "
                          f"The border will be ignored, add backbone_padding for pixel-wise tasks.")

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x


class TemporalEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer('scale', torch.ones(1))

    def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
        """
        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1

        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)

        embedding = self.scale * torch.cat([year, julian_day], dim=-1)

        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)

        return embedding  # B, T*tokens_per_frame, embed_dim


class LocationEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer('scale', torch.ones(1))

    def forward(self, location_coords: torch.Tensor):
        """
        location_coords: lat and lon info with shape (B, 2).
        """
        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1

        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)

        embedding = self.scale * torch.cat([lat, lon], dim=-1)

        return embedding  # B, 1, embed_dim


class PrithviViT(nn.Module):
    """ Prithvi ViT Encoder"""
    def __init__(self,
                 img_size: int | tuple[int, int] = 224,
                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                 num_frames: int = 1,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 depth: int = 24,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 coords_encoding: list[str] | None = None,
                 coords_scale_learn: bool = False,
                 drop_path: float = 0.,
                 **kwargs,
                 ):
        super().__init__()

        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                                     drop_path=drop_path,))
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(self, sequence, mask_ratio, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio (float): mask ratio to use.
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.pos_embed,
            grid_size=self.patch_embed.grid_size,
            patch_size=self.patch_embed.patch_size,
            shape=sample_shape,
            embed_dim=self.embed_dim,
        )
        return pos_embed

    def forward(
        self, x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.75
    ):
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                e=self.embed_dim,
                t=effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out


class MAEDecoder(nn.Module):
    """ Transformer Decoder used in the Prithvi MAE"""
    def __init__(self,
                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                 grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 depth: int = 8,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 coords_encoding: list[str] | None = None,
                 coords_scale_learn: bool = False,
                 ):
        super().__init__()

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList(
            [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim,
                                      patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
                                      bias=True)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.decoder_pos_embed,
            grid_size=self.grid_size,
            patch_size=self.patch_size,
            shape=sample_shape,
            embed_dim=self.decoder_embed_dim,
        )

        return pos_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] = None,
    ):
        # embed tokens
        x = self.decoder_embed(hidden_states)
        cls_token = x[:, :1, :]

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device))

        # add pos embed
        decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])
        cls_token = cls_token + decoder_pos_embed[:, :1, :]
        x = x + decoder_pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
            # Add temporal encoding w/o cls token
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
            x = x + location_encoding

        # append cls token
        x = torch.cat([cls_token, x], dim=1)

        # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        pred = pred[:, 1:, :]

        return pred


class PrithviMAE(nn.Module):
    """ Prithvi Masked Autoencoder"""

    def __init__(self,
                 img_size: int | tuple[int, int] = 224,
                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                 num_frames: int = 4,
                 in_chans: int = 6,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 coords_encoding: list[str] | None = None,
                 coords_scale_learn: bool = False,
                 drop_path: float = 0.,
                 mask_ratio: float = 0.75,
                 **kwargs,
                 ):
        super().__init__()

        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
            drop_path=drop_path,
        )

        self.decoder = MAEDecoder(
            patch_size=patch_size,
            grid_size=self.encoder.patch_embed.grid_size,
            in_chans=in_chans,
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )

        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss
        self.out_channels = self.encoder.out_channels

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.

        Returns:
            torch.FloatTensor of shape
            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans

        # patchify
        patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
                                            c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)

        return patchified_pixel_values

    def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
                Patchified pixel values.
            image_size (`tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans

        pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
                                 c=num_channels, h=num_patches_h, w=num_patches_w,
                                 s=patch_size_t, p=patch_size_h, q=patch_size_w)
        return pixel_values

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = None,
    ):
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)

        mask_ratio = mask_ratio or self.mask_ratio
        latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
        pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        return self.encoder.forward_features(x, temporal_coords, location_coords)
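Note (not part of the commit): a minimal smoke-test sketch of the PrithviMAE interface defined above, using randomly initialized weights and random inputs with the Prithvi-EO-2.0-tiny hyperparameters from config.json; the coordinate values are illustrative only.

import torch
from prithvi_mae import PrithviMAE

# Hyperparameters mirror config.json's pretrained_cfg (tiny variant); no checkpoint is loaded.
model = PrithviMAE(
    img_size=224, patch_size=[1, 16, 16], num_frames=4, in_chans=6,
    embed_dim=192, depth=12, num_heads=3,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4, coords_encoding=["time", "location"], coords_scale_learn=True,
)
model.eval()

pixel_values = torch.randn(1, 6, 4, 224, 224)  # (B, C, T, H, W)
temporal_coords = torch.tensor([[[2018., 26.], [2018., 106.], [2018., 201.], [2018., 266.]]])  # (B, T, 2): year, day-of-year
location_coords = torch.tensor([[31.0, -104.0]])  # (B, 2): lat, lon (illustrative values)

with torch.no_grad():
    loss, pred, mask = model(pixel_values, temporal_coords, location_coords, mask_ratio=0.75)

print(pred.shape)                    # (1, 784, 1536): 4*14*14 patches, each 1*16*16*6 values
print(model.unpatchify(pred).shape)  # (1, 6, 4, 224, 224)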
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
timm
einops
rasterio
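Setup note (an assumption about the intended workflow, not stated in the commit): install the dependencies with "pip install -r requirements.txt" under Python 3.10 or newer, since prithvi_mae.py uses "int | None"-style annotations in function signatures, then run "python inference.py" from the repository root. Note that inference.py also imports yaml and pandas, which are not listed in requirements.txt.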