# finetune_wms_dataset.py (replacement)
"""Folder-grouped image-triplet dataset.

Images found under a root directory are grouped by an ancestor folder,
naturally sorted inside each group, and turned into sliding-window triplets
``(I0, It, I1)`` of consecutive images.  Triplets never cross group
boundaries.
"""
import os
import random
import re
import zlib
from typing import Dict, List, Tuple

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

IMG_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

# Capturing group: re.split keeps the digit runs, so pieces alternate
# text / digits.
_DIGIT_RUN_RE = re.compile(r'(\d+)')


def _natural_key(s: str):
    """Return a sort key that orders embedded numbers numerically ("f2" < "f10").

    ``re.split`` with a capturing group yields text pieces at even positions and
    digit runs at odd positions, so two keys always compare str-to-str and
    int-to-int.
    """
    # isdecimal (not isdigit): every \d run is int()-convertible, whereas
    # isdigit() also accepts characters such as '²' that int() rejects.
    return [int(t) if t.isdecimal() else t.lower() for t in _DIGIT_RUN_RE.split(s)]


def list_images_grouped_by_folder(root_dir: str, group_level: int = 1) -> Dict[str, List[str]]:
    """Walk ``root_dir`` and group image paths by an ancestor folder.

    Args:
        root_dir: directory to walk recursively.
        group_level: 1 groups by the immediate parent folder, 2 by the
            grandparent, and so on.  With group_level >= 2, images from sibling
            sub-folders are merged into one group and sorted together.

    Returns:
        Mapping from normalized absolute folder path to the image paths (as
        produced by ``os.walk``) in that group, naturally sorted.  Groups are
        emitted in natural-sorted key order so the result does not depend on
        the filesystem's directory listing order.
    """
    groups: Dict[str, List[str]] = {}
    for dirpath, _, files in os.walk(root_dir):
        for name in files:
            if not name.lower().endswith(IMG_EXTS):
                continue
            path = os.path.join(dirpath, name)
            key = path
            for _ in range(group_level):
                key = os.path.dirname(key)
            key = os.path.normpath(os.path.abspath(key))
            groups.setdefault(key, []).append(path)
    return {
        key: sorted(paths, key=_natural_key)
        for key, paths in sorted(groups.items(), key=lambda kv: _natural_key(kv[0]))
    }


def build_triplets_within_groups(groups: Dict[str, List[str]], stride: int = 1) -> List[Tuple[str, str, str]]:
    """Build sliding-window triplets of consecutive images inside each group.

    Groups with fewer than 3 images contribute nothing.  ``stride`` is the step
    between the first frames of successive windows.

    Raises:
        ValueError: if ``stride`` < 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    triplets: List[Tuple[str, str, str]] = []
    for files in groups.values():
        # range() is empty when len(files) < 3, which skips short groups.
        for i in range(0, len(files) - 2, stride):
            triplets.append((files[i], files[i + 1], files[i + 2]))
    return triplets


def _pad_reflect(img: np.ndarray, top: int, bottom: int, left: int, right: int) -> np.ndarray:
    """Pad ``img`` with OpenCV's BORDER_REFLECT_101 (mirror without repeating the edge pixel)."""
    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_REFLECT_101)


def _read_image(path: str) -> np.ndarray:
    """Read ``path`` as an H x W x 3 RGB array, keeping the file's native dtype.

    Grayscale, gray+alpha, BGR and BGRA inputs are supported; alpha is dropped.

    Raises:
        RuntimeError: if OpenCV cannot decode the file.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise RuntimeError(f"cv2 failed to read: {path}")
    if img.ndim == 3 and img.shape[2] in (1, 2):
        # H x W x 1 or gray+alpha: keep only the gray plane.
        img = np.ascontiguousarray(img[:, :, 0])
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


class WMSDataset(Dataset):
    """Triplet dataset built per folder and robust to mixed image sizes.

    Each item is ``(sample, t)``:
      * ``sample``: float32 tensor of shape (9, H, W) holding
        ``cat([I0, I1, It])`` (each frame CHW, RGB, scaled to [0, 1]); ``It`` is
        the middle image of the triplet.  When ``crop_size > 0``, H and W equal
        ``crop_size`` rounded up to ``multiple_of``.
      * ``t``: ``tensor([0.5])``.
    """

    def __init__(self, root: str, split: str = 'train', crop_size: int = 256, stride: int = 1,
                 color_mode: str = 'rgb', augment: bool = True, group_level: int = 1,
                 mode: str = 'pad', multiple_of: int = 32, val_percent: int = 10):
        """
        Args:
            root: directory scanned recursively for images.
            split: 'train' or 'val'.
            crop_size: side of the square crop (<= 0 disables cropping).
            stride: step between successive triplet windows within a folder.
            color_mode: stored on the instance; not used by this class.
            augment: random crops and flips; forced off for 'val'.
            group_level: 1 = parent folder, 2 = grandparent, etc.
            mode: 'pad' (default), 'resize_shorter', or 'scale_and_crop'.
            multiple_of: pad final crops (bottom/right) to this multiple.
            val_percent: percentage of triplets (0-100) assigned to 'val'.

        Raises:
            ValueError: on an invalid split, mode, group_level, stride or
                val_percent.
        """
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        if mode not in ('pad', 'resize_shorter', 'scale_and_crop'):
            raise ValueError(f"mode must be 'pad', 'resize_shorter' or 'scale_and_crop', got {mode!r}")
        if int(group_level) < 1:
            raise ValueError(f"group_level must be >= 1, got {group_level}")
        if not 0 <= int(val_percent) <= 100:
            raise ValueError(f"val_percent must be in [0, 100], got {val_percent}")

        self.root = root
        self.crop_size = int(crop_size)
        self.stride = int(stride)
        self.augment = bool(augment) and split == 'train'
        self.color_mode = color_mode
        self.group_level = int(group_level)
        self.mode = mode
        self.multiple_of = int(multiple_of)
        self.val_percent = int(val_percent)

        groups = list_images_grouped_by_folder(root, group_level=self.group_level)
        all_trips = build_triplets_within_groups(groups, stride=self.stride)

        # Stable split keyed on the middle frame: buckets [0, 100 - val_percent)
        # are train, the rest are val.
        train_cutoff = 100 - self.val_percent
        want_val = split == 'val'
        self.samples: List[Tuple[str, str, str]] = [
            trip for trip in all_trips
            if (self._split_bucket(trip[1]) >= train_cutoff) == want_val
        ]

    def _split_bucket(self, path: str) -> int:
        """Map ``path`` to a bucket in [0, 100) that is identical across runs.

        The built-in ``hash`` of a str is salted per interpreter process
        (PYTHONHASHSEED), so it would reshuffle train/val on every run.  CRC32
        of the root-relative path is also independent of where ``root`` lives.
        """
        rel = os.path.relpath(path, self.root).replace(os.sep, '/')
        # surrogateescape: os.walk may yield undecodable names as surrogates.
        return zlib.crc32(rel.encode('utf-8', 'surrogateescape')) % 100

    def __len__(self):
        return len(self.samples)

    def _rescale_for_mode(self, img: np.ndarray) -> np.ndarray:
        """Upscale ``img`` according to ``self.mode`` so it can hold a crop.

        'pad' never resizes.  'resize_shorter' scales a frame whose shorter side
        is below ``crop_size`` so that side becomes ``crop_size``.
        'scale_and_crop' scales a frame with any side below ``crop_size`` by
        max(crop/h, crop/w).
        NOTE: that max equals crop/min(h, w), so both resize modes currently
        compute the same scale.
        """
        h, w = img.shape[:2]
        cs = self.crop_size
        if self.mode == 'resize_shorter' and min(h, w) < cs:
            scale = cs / max(1, min(h, w))
        elif self.mode == 'scale_and_crop' and (h < cs or w < cs):
            scale = max(cs / max(1, h), cs / max(1, w))
        else:
            return img
        # cv2.resize takes dsize as (width, height).
        new_size = (int(round(w * scale)), int(round(h * scale)))
        return cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)

    @staticmethod
    def _pad_to_size(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
        """Reflect-pad ``img`` symmetrically to at least (target_h, target_w)."""
        h, w = img.shape[:2]
        pad_h = max(0, target_h - h)
        pad_w = max(0, target_w - w)
        if pad_h == 0 and pad_w == 0:
            return img
        top = pad_h // 2
        left = pad_w // 2
        return _pad_reflect(img, top, pad_h - top, left, pad_w - left)

    def _ensure_min_size_and_mode(self, img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
        """Apply the size mode, then reflect-pad to at least (target_h, target_w).

        Kept for compatibility; ``__getitem__`` performs the two steps
        separately so the padding target accounts for rescaled sizes.
        """
        return self._pad_to_size(self._rescale_for_mode(img), target_h, target_w)

    def _pad_to_multiple(self, img: np.ndarray, multiple: int) -> np.ndarray:
        """Reflect-pad bottom/right so both sides are multiples of ``multiple``.

        Padding only bottom/right keeps pixel (0, 0) in place.
        """
        if multiple <= 1:
            return img
        h, w = img.shape[:2]
        ph = -(-h // multiple) * multiple  # ceil(h / multiple) * multiple
        pw = -(-w // multiple) * multiple
        if ph == h and pw == w:
            return img
        return _pad_reflect(img, 0, ph - h, 0, pw - w)

    @staticmethod
    def _to_chw_tensor(img: np.ndarray) -> torch.Tensor:
        """Convert an H x W x C image to a float32 C x H x W tensor in [0, 1].

        Integer images are divided by their dtype's maximum (255 for uint8,
        65535 for uint16).  Floating-point images are assumed to already be in
        [0, 1] and are not rescaled.
        """
        if np.issubdtype(img.dtype, np.integer):
            arr = img.astype(np.float32) / float(np.iinfo(img.dtype).max)
        else:
            arr = img.astype(np.float32)
        return torch.from_numpy(arr.transpose(2, 0, 1))

    def __getitem__(self, idx):
        p0, pt, p1 = self.samples[idx]

        # 1) read and apply the size mode to each frame independently
        frames = [self._rescale_for_mode(_read_image(p)) for p in (p0, pt, p1)]

        # 2) common canvas: at least crop_size and large enough for every frame
        #    *after* rescaling, so all three frames end up with identical shapes
        cs = self.crop_size
        base_h = max([cs] + [f.shape[0] for f in frames])
        base_w = max([cs] + [f.shape[1] for f in frames])
        frames = [self._pad_to_size(f, base_h, base_w) for f in frames]

        # 3) crop (random for train, center for val); the canvas is >= cs on
        #    both sides, so the crop always fits
        if cs > 0:
            if self.augment:
                y = random.randint(0, base_h - cs)
                x = random.randint(0, base_w - cs)
            else:
                y = (base_h - cs) // 2
                x = (base_w - cs) // 2
            frames = [f[y:y + cs, x:x + cs] for f in frames]

        # 4) augment flips, applied identically to all frames; copy() removes
        #    the negative strides that OpenCV/torch cannot consume
        if self.augment:
            if random.random() < 0.5:
                frames = [np.fliplr(f).copy() for f in frames]
            if random.random() < 0.5:
                frames = [np.flipud(f).copy() for f in frames]

        # 5) pad to multiple_of
        if self.multiple_of > 1:
            frames = [self._pad_to_multiple(f, self.multiple_of) for f in frames]

        # 6) to float [0, 1] CHW tensors
        T0, Tt, T1 = (self._to_chw_tensor(f) for f in frames)

        # sanity checks
        if 0 in T0.shape or 0 in Tt.shape or 0 in T1.shape:
            raise RuntimeError(f"Zero-sized tensor after processing idx {idx} paths: {p0}, {pt}, {p1}")
        if not (T0.shape == Tt.shape == T1.shape):
            raise RuntimeError(
                f"Mismatched frame shapes {tuple(T0.shape)}, {tuple(Tt.shape)}, {tuple(T1.shape)} "
                f"after processing idx {idx} paths: {p0}, {pt}, {p1}")

        sample = torch.cat([T0, T1, Tt], dim=0)  # (9, H, W)
        t = torch.tensor([0.5], dtype=torch.float32)
        return sample, t


# optional utility: print counts per folder (useful for debugging)
def folder_image_counts(root_dir: str, group_level: int = 1):
    """Return {group folder: number of images} for ``root_dir``."""
    groups = list_images_grouped_by_folder(root_dir, group_level=group_level)
    return {k: len(v) for k, v in groups.items()}