blumenstiel committed on
Commit
2b43c93
·
1 Parent(s): 78711ff

Add inference code

assets/model_architecture.png ADDED

Git LFS Details

  • SHA256: 30d14e91bfaf1ec39a182254bb7cbdf3b98ae87d941b846c96d2042269c46cdb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.84 MB
config.json ADDED
@@ -0,0 +1,26 @@
{
  "architecture": "prithvi_eo_v2_tiny",
  "num_features": 192,
  "pretrained_cfg": {
    "img_size": 224,
    "num_frames": 4,
    "patch_size": [1, 16, 16],
    "in_chans": 6,
    "embed_dim": 192,
    "depth": 12,
    "num_heads": 3,
    "decoder_embed_dim": 512,
    "decoder_depth": 8,
    "decoder_num_heads": 16,
    "mlp_ratio": 4,
    "coords_encoding": ["time", "location"],
    "coords_scale_learn": true,
    "mask_ratio": 0.75,
    "norm_pix_loss": false,
    "bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
    "mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
    "std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
    "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny",
    "paper_ids": "arXiv:X.X"
  }
}
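The "pretrained_cfg" block mirrors the keyword arguments of PrithviMAE in prithvi_mae.py, so the config can be unpacked directly into the model constructor (this is what inference.py does via yaml.safe_load). A minimal sketch, assuming the config.json added above sits in the working directory; extra keys such as "bands", "mean" and "std" are absorbed by the constructor's **kwargs:

import json
from prithvi_mae import PrithviMAE

# Load the pretrained_cfg from this commit and build the tiny model from it.
with open("config.json") as f:
    cfg = json.load(f)["pretrained_cfg"]

model = PrithviMAE(**cfg)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the tiny variant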
examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: e34c1e8f6b69092bbf16f87da1a0c2337e8e53f28d172d8076e2efab292b795d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: a4b24a34d83d25cac7dbcb7742db3f5b1e4849e5773172c1f0fc43c541bcd3fd
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: fce050cc821ebec2974e85cfe702c0f093d74caf12196adb7ee88c8a30773d4f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: f7f8c67c32027cd663f48226a5932c6c8119a55fb6e80a02636dea57f4733963
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
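The four example GeoTIFFs are consecutive HLS S30 acquisitions of the same tile (T13REM) from 2018, which inference.py below treats as four time steps of one location. A small sketch for sanity-checking one of the files, assuming rasterio is installed; the config above lists six bands (B02–B07), so six bands should be reported:

import rasterio

with rasterio.open("examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif") as src:
    print(src.count, src.width, src.height, src.dtypes[0])
    print(src.lnglat())  # approximate (lon, lat) of the chip center, used as location coords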
inference.py ADDED
@@ -0,0 +1,528 @@
import argparse
import os
import re
import datetime
from typing import List, Union

import numpy as np
import rasterio
import torch
import yaml
from einops import rearrange

from prithvi_mae import PrithviMAE

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99.9


def process_channel_group(orig_img, new_img, channels, mean, std):
    """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
    original range using *mean* and *std*, then the lowest and highest percentiles are
    removed to enhance contrast. Data is rescaled to the (0, 1) range and stacked channels-first.

    Args:
        orig_img: torch.Tensor representing the original image (reference) with shape (bands, H, W).
        new_img: torch.Tensor representing the image with shape (bands, H, W).
        channels: list of indices representing the RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.

    Returns:
        torch.Tensor with shape (num_channels, height, width) for the original image
        torch.Tensor with shape (num_channels, height, width) for the other image
    """

    mean = torch.tensor(np.asarray(mean)[:, None, None])  # C H W
    std = torch.tensor(np.asarray(std)[:, None, None])
    orig_img = orig_img[channels, ...]
    valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
    valid_mask[orig_img == NO_DATA_FLOAT] = False

    # Back to original data range
    orig_img = (orig_img * std[channels]) + mean[channels]
    new_img = (new_img[channels, ...] * std[channels]) + mean[channels]

    # Rescale (enhancing contrast)
    max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
    new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)

    # No data as zeros
    orig_img[~valid_mask] = 0
    new_img[~valid_mask] = 0

    return orig_img, new_img


def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return the image, meta info and center coordinates.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
        (lon, lat) tuple, or None if the coordinates cannot be read
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.lnglat()
        except Exception:
            # Cannot read coords
            coords = None

    return img, meta, coords


def save_geotiff(image, output_path: str, meta: dict):
    """Save a multi-band image as a GeoTIFF file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):
    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(
    file_paths: List[str],
    mean: List[float],
    std: List[float],
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading the images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing the mean value for each band in the images in *file_paths*.
        std: list containing the std value for each band in the images in *file_paths*.
        indices: optional list of band indices to select from the input images.

    Returns:
        np.array containing the created example
        list of [year, day-of-year] temporal coordinates per image
        list of (lon, lat) location coordinates per image
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            match = re.search(r'(\d{7,8}T\d{6})', file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split('T')[0][4:]
                if len(julian_day) == 3:
                    julian_day = int(julian_day)
                else:
                    julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f'Could not extract timestamp for {file} ({e})')

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas


def run_model(
    model: torch.nn.Module,
    input_data: torch.Tensor,
    temporal_coords: None | torch.Tensor,
    location_coords: None | torch.Tensor,
    mask_ratio: float,
    device: torch.device,
):
    """Run *model* with *input_data* and create images from the output tokens (mask, reconstructed + visible).

    Args:
        model: MAE model to run.
        input_data: torch.Tensor with shape (B, C, T, H, W).
        temporal_coords: temporal coordinates tensor or None.
        location_coords: location coordinates tensor or None.
        mask_ratio: mask ratio to use.
        device: device where the model should run.

    Returns:
        2 torch.Tensors with shape (B, C, T, H, W): the reconstruction and the mask image.
    """

    with torch.no_grad():
        x = input_data.to(device)

        _, pred, mask = model(x, temporal_coords, location_coords, mask_ratio)

    # Create mask and prediction images (un-patchify)
    mask_img = (
        model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
    )
    pred_img = model.unpatchify(pred).detach().cpu()

    # Mix visible and predicted patches
    rec_img = input_data.clone()
    rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove

    # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
    mask_img = (~(mask_img.to(torch.bool))).to(torch.float)

    return rec_img, mask_img


def save_rgb_imgs(
    input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
):
    """Wrapper function to save GeoTIFF images (original, reconstructed, masked) per timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing the RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(
            orig_img=input_img[:, t, :, :],
            new_img=rec_img[:, t, :, :],
            channels=channels,
            mean=mean,
            std=std,
        )

        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # Saving images

        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
            meta=meta_data[t],
        )

        save_geotiff(
            image=_convert_np_uint8(rgb_pred),
            output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
            meta=meta_data[t],
        )

        save_geotiff(
            image=_convert_np_uint8(rgb_mask),
            output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
            meta=meta_data[t],
        )


def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
    """Wrapper function to save GeoTIFF images (reconstructed, mask) per timestamp.

    Args:
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    mean = torch.tensor(np.asarray(mean)[:, None, None])  # C H W
    std = torch.tensor(np.asarray(std)[:, None, None])

    for t in range(rec_img.shape[1]):
        # Back to original data range
        rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)

        mask_img_t = mask_img[:, t, :, :].to(torch.int16)

        # Saving images

        save_geotiff(
            image=rec_img_t,
            output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
            meta=meta_data[t],
        )

        save_geotiff(
            image=mask_img_t,
            output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
            meta=meta_data[t],
        )


def main(
    data_files: List[str],
    config_path: str,
    checkpoint: str,
    output_dir: str,
    rgb_outputs: bool,
    mask_ratio: float = None,
    input_indices: list[int] = None,
):
    os.makedirs(output_dir, exist_ok=True)

    # Get parameters --------

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)['pretrained_cfg']

    batch_size = 1
    bands = config['bands']
    num_frames = len(data_files)
    mean = config['mean']
    std = config['std']
    coords_encoding = config['coords_encoding']
    img_size = config['img_size']
    mask_ratio = mask_ratio or config['mask_ratio']

    print(
        f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
    )
    if len(data_files) != 4:
        print(
            "The original model was trained for four time steps.\nResults with a different number of time steps may vary."
        )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print(f"Using {device} device.\n")

    # Loading data ---------------------------------------------------------------------------------

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=data_files, indices=input_indices, mean=mean, std=std
    )

    if len(temporal_coords) != num_frames and 'time' in coords_encoding:
        coords_encoding.remove('time')
    if not len(location_coords) and 'location' in coords_encoding:
        coords_encoding.remove('location')

    # Create model and load checkpoint -------------------------------------------------------------

    config.update(
        coords_encoding=coords_encoding,
        num_frames=num_frames,
        in_chans=len(bands),
    )

    model = PrithviMAE(**config)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
    # discard fixed pos_embedding weights
    for k in list(state_dict.keys()):
        if k == 'encoder.pos_embed':
            state_dict[k] = model.encoder.pos_embed
        elif k == 'decoder.decoder_pos_embed':
            state_dict[k] = model.decoder.decoder_pos_embed
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded checkpoint from {checkpoint}")

    # Running model --------------------------------------------------------------------------------

    model.eval()
    channels = [bands.index(b) for b in ["B04", "B03", "B02"]]  # BGR -> RGB

    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = img_size - (original_h % img_size)
    pad_w = img_size - (original_w % img_size)
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window
    batch = torch.tensor(input_data, device="cpu")
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    temporal_coords = torch.tensor(temporal_coords, dtype=torch.float, device=device).unsqueeze(0)
    location_coords = torch.tensor(location_coords[0], dtype=torch.float, device=device).unsqueeze(0)

    # Run model
    rec_imgs = []
    mask_imgs = []
    for x in windows:
        rec_img, mask_img = run_model(model, x, temporal_coords, location_coords, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

    # Build images from patches
    rec_imgs = rearrange(
        rec_imgs,
        "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=len(bands),
        t=num_frames,
        h1=h1,
        w1=w1,
    )
    mask_imgs = rearrange(
        mask_imgs,
        "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=len(bands),
        t=num_frames,
        h1=h1,
        w1=w1,
    )

    # Cut padded images back to original size
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    # Build output images
    if rgb_outputs:
        for d in meta_data:
            d.update(count=3, dtype="uint8", compress="lzw", nodata=0)

        save_rgb_imgs(
            batch_full[0, ...],
            rec_imgs_full[0, ...],
            mask_imgs_full[0, ...],
            channels,
            mean,
            std,
            output_dir,
            meta_data,
        )
    else:
        for d in meta_data:
            d.update(compress="lzw", nodata=0)

        save_imgs(
            rec_imgs_full[0, ...],
            mask_imgs_full[0, ...],
            mean,
            std,
            output_dir,
            meta_data,
        )

    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_files",
        type=str,
        nargs="+",
        default=["examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
                 "examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
                 "examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
                 "examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
                 ],
        help="Path to the data files. Assumes multi-band files.",
    )
    parser.add_argument(
        "--config_path",
        "-c",
        type=str,
        default="config.json",
        help="Path to a JSON file containing the model training parameters.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="Prithvi_EO_V2_tiny.pt",
        help="Path to a checkpoint file to load from.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--mask_ratio",
        default=0.75,
        type=float,
        help="Masking ratio (percentage of removed patches). "
             "Defaults to 0.75, the value used for pretraining.",
    )
    parser.add_argument(
        "--input_indices",
        default=None,
        type=int,
        nargs="+",
        help="0-based indices of the channels to be selected from the input. By default takes all.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
             "Otherwise, all bands will be saved.",
    )
    args = parser.parse_args()

    main(**vars(args))
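A minimal sketch of running the script from Python rather than the CLI, using the defaults shown above. It assumes config.json and the checkpoint Prithvi_EO_V2_tiny.pt (the script's default checkpoint name, not part of this commit) are available next to inference.py:

from inference import main

# Reconstruct the four example time steps with the default 0.75 mask ratio and save RGB previews.
main(
    data_files=[
        "examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
        "examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
        "examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
        "examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
    ],
    config_path="config.json",
    checkpoint="Prithvi_EO_V2_tiny.pt",
    output_dir="output",
    rgb_outputs=True,
    mask_ratio=0.75,
)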
prithvi_mae.py ADDED
@@ -0,0 +1,766 @@
# Copyright (c) IBM Corp. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# transformers: https://github.com/huggingface/transformers
# --------------------------------------------------------

import warnings
import logging

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block

logger = logging.getLogger(__name__)


def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 3D sin/cos positional embeddings.

    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.

    Returns:
        `torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded, size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
    """Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded, size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]

    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)

    return emb


def _init_weights(module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def _interpolate_pos_encoding(
    pos_embed: torch.Tensor,
    grid_size: tuple[int, int, int] | list[int],
    patch_size: tuple[int, int, int] | list[int],
    shape: tuple[int, int, int],
    embed_dim: int,
):
    """
    Adapted from:
    - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
    - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
    """
    t, h, w = shape
    t_patches = t // patch_size[0]
    h_patches = h // patch_size[1]
    w_patches = w // patch_size[2]

    if [t_patches, h_patches, w_patches] == grid_size:
        # No interpolation needed
        return pos_embed
    if t_patches != grid_size[0]:
        # Re-compute pos embedding to handle changed num_frames
        new_grid_size = (t_patches, *grid_size[1:])
        new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True)
        new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
    else:
        new_grid_size = grid_size
        new_pos_embed = pos_embed

    class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]

    patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2)

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(h_patches, w_patches),
        mode='bicubic',
        align_corners=True,
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed"""
    def __init__(
        self,
        input_size: tuple[int, int, int] = (1, 224, 224),
        patch_size: tuple[int, int, int] = (1, 16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape

        if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
            warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
                          f" The border will be ignored, add backbone_padding for pixel-wise tasks.")

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x


class TemporalEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer('scale', torch.ones(1))

    def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
        """
        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over the T dimension, and the final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1

        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)

        embedding = self.scale * torch.cat([year, julian_day], dim=-1)

        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)

        return embedding  # B, T*tokens_per_frame, embed_dim


class LocationEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer('scale', torch.ones(1))

    def forward(self, location_coords: torch.Tensor):
        """
        location_coords: lat and lon info with shape (B, 2).
        """
        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1

        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)

        embedding = self.scale * torch.cat([lat, lon], dim=-1)

        return embedding  # B, 1, embed_dim


class PrithviViT(nn.Module):
    """Prithvi ViT Encoder"""
    def __init__(self,
                 img_size: int | tuple[int, int] = 224,
                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                 num_frames: int = 1,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 depth: int = 24,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 coords_encoding: list[str] | None = None,
                 coords_scale_learn: bool = False,
                 drop_path: float = 0.,
                 **kwargs,
                 ):
        super().__init__()

        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                                     drop_path=drop_path,))
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(self, sequence, mask_ratio, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio (float): mask ratio to use.
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                mainly used for testing purposes to control randomness and maintain reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.pos_embed,
            grid_size=self.patch_embed.grid_size,
            patch_size=self.patch_embed.patch_size,
            shape=sample_shape,
            embed_dim=self.embed_dim,
        )
        return pos_embed

    def forward(
        self, x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.75
    ):
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                e=self.embed_dim,
                t=effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out


class MAEDecoder(nn.Module):
    """Transformer Decoder used in the Prithvi MAE"""
    def __init__(self,
                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                 grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 depth: int = 8,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 coords_encoding: list[str] | None = None,
                 coords_scale_learn: bool = False,
                 ):
        super().__init__()

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList(
            [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim,
                                      patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
                                      bias=True)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.decoder_pos_embed,
            grid_size=self.grid_size,
            patch_size=self.patch_size,
            shape=sample_shape,
            embed_dim=self.decoder_embed_dim,
        )

        return pos_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] = None,
    ):
        # embed tokens
        x = self.decoder_embed(hidden_states)
        cls_token = x[:, :1, :]

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device))

        # add pos embed
        decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])
        cls_token = cls_token + decoder_pos_embed[:, :1, :]
        x = x + decoder_pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
            # Add temporal encoding w/o cls token
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
            x = x + location_encoding

        # append cls token
        x = torch.cat([cls_token, x], dim=1)

        # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        pred = pred[:, 1:, :]

        return pred


class PrithviMAE(nn.Module):
    """Prithvi Masked Autoencoder"""

    def __init__(self,
                 img_size: int | tuple[int, int] = 224,
                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                 num_frames: int = 4,
                 in_chans: int = 6,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: nn.Module = nn.LayerNorm,
                 norm_pix_loss: bool = False,
                 coords_encoding: list[str] | None = None,
                 coords_scale_learn: bool = False,
                 drop_path: float = 0.,
                 mask_ratio: float = 0.75,
                 **kwargs,
                 ):
        super().__init__()

        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
            drop_path=drop_path,
        )

        self.decoder = MAEDecoder(
            patch_size=patch_size,
            grid_size=self.encoder.patch_embed.grid_size,
            in_chans=in_chans,
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )

        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss
        self.out_channels = self.encoder.out_channels

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.

        Returns:
            `torch.FloatTensor` of shape
            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans

        # patchify
        patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
                                            c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)

        return patchified_pixel_values

    def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Patchified pixel values.
            image_size (`tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans

        pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
                                 c=num_channels, h=num_patches_h, w=num_patches_w,
                                 s=patch_size_t, p=patch_size_h, q=patch_size_w)
        return pixel_values

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = None,
    ):
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)

        mask_ratio = mask_ratio or self.mask_ratio
        latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
        pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        return self.encoder.forward_features(x, temporal_coords, location_coords)
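A minimal smoke test for the model class, using the tiny hyperparameters from config.json and random input. The temporal coordinates reuse the year/day-of-year values of the four example files; the (lat, lon) pair is a placeholder, and the shapes follow the (B, C, T, H, W) convention used throughout this file:

import torch
from prithvi_mae import PrithviMAE

# Tiny variant from config.json: 6 bands, 4 frames, 224x224, embed_dim 192, depth 12, 3 heads.
model = PrithviMAE(img_size=224, patch_size=(1, 16, 16), num_frames=4, in_chans=6,
                   embed_dim=192, depth=12, num_heads=3,
                   decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                   coords_encoding=["time", "location"], coords_scale_learn=True)

pixel_values = torch.randn(1, 6, 4, 224, 224)  # B, C, T, H, W
temporal_coords = torch.tensor([[[2018, 26], [2018, 106], [2018, 201], [2018, 266]]], dtype=torch.float)
location_coords = torch.tensor([[29.2, -104.5]], dtype=torch.float)  # placeholder (lat, lon)

loss, pred, mask = model(pixel_values, temporal_coords, location_coords, mask_ratio=0.75)
print(loss.item(), pred.shape, mask.shape)  # pred: (1, 784, 1536), mask: (1, 784)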
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch
torchvision
timm
einops
rasterio