import glob
import os

import cv2
import numpy as np
from torch.utils.data import Dataset


def get_roi(mask, margin=8):
    """Compute a roughly square region of interest around the mask foreground.

    Args:
        mask: 2-D array; non-zero entries mark the foreground. Must not be
            ``None`` (its shape is needed to describe the full frame).
        margin: Number of pixels added around the foreground bounding box.

    Returns:
        ``np.array([h0, w0, r_s, r_e, c_s, c_e])`` where ``(h0, w0)`` is the
        mask size and ``[r_s:r_e, c_s:c_e]`` is the crop window. The window
        covers the whole frame when the mask has no foreground pixel or when
        the padded bounding box touches or crosses the image border.
    """
    h0, w0 = mask.shape[:2]
    full_frame = np.array([h0, w0, 0, h0, 0, w0])

    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        # No foreground at all: np.min/np.max would raise on empty input.
        return full_frame

    rowmin, rowmax = np.min(rows), np.max(rows)
    colmin, colmax = np.min(cols), np.max(cols)
    row, col = rowmax - rowmin, colmax - colmin

    fits_inside = not (
        rowmin - margin <= 0
        or rowmax + margin > h0
        or colmin - margin <= 0
        or colmax + margin > w0
    )
    if not fits_inside:
        return full_frame

    if row > col:
        # Taller than wide: widen the column range to make the crop square.
        pad = int(0.5 * (row - col))
        r_s, r_e = rowmin - margin, rowmax + margin
        c_s = max(colmin - pad - margin, 0)
        c_e = min(colmax + pad + margin, w0)
    else:
        # Wider than tall (or equal): extend the row range instead.
        pad = int(0.5 * (col - row))
        r_s = max(rowmin - pad - margin, 0)
        r_e = min(rowmax + pad + margin, h0)
        c_s, c_e = colmin - margin, colmax + margin

    return np.array([h0, w0, r_s, r_e, c_s, c_e])


def _square_target_size(shape, max_image_resolution):
    """Return the square side length used when resizing a crop.

    The longer side of ``shape`` is rounded down to a multiple of 512 and
    clamped to ``[512, max_image_resolution]``.
    """
    return max(512, min(max_image_resolution, (max(shape[:2]) // 512) * 512))


def crop_and_resize_img(img, roi, max_image_resolution=6000):
    """Crop an image to the ROI, resize it to a square and scale to float32.

    Args:
        img: ``(H, W, C)`` image array.
        roi: Sequence ``[h0, w0, r_s, r_e, c_s, c_e]`` as returned by
            :func:`get_roi`.
        max_image_resolution: Upper bound on the output side length.

    Returns:
        float32 array of shape ``(S, S, C)``. ``uint8`` input is divided by
        255, ``uint16`` by 65535, any other dtype is only cast to float32.
    """
    h0, w0, r_s, r_e, c_s, c_e = roi
    img = img[r_s:r_e, c_s:c_e, :]
    side = _square_target_size(img.shape, max_image_resolution)
    img = cv2.resize(img, (side, side), interpolation=cv2.INTER_CUBIC)

    if img.dtype == np.uint8:
        bit_depth = 255.0
    elif img.dtype == np.uint16:
        bit_depth = 65535.0
    else:
        bit_depth = 1.0
    return np.float32(img) / bit_depth


def crop_and_resize_mask(mask, roi, max_image_resolution=6000):
    """Crop a mask to the ROI, resize it to a square and binarize it.

    Args:
        mask: ``(H, W)`` mask array.
        roi: Sequence ``[h0, w0, r_s, r_e, c_s, c_e]`` as returned by
            :func:`get_roi`.
        max_image_resolution: Upper bound on the output side length.

    Returns:
        float32 array of shape ``(S, S)`` holding 1.0 where the resized mask
        exceeds 0.5 and 0.0 elsewhere.
    """
    h0, w0, r_s, r_e, c_s, c_e = roi
    mask = mask[r_s:r_e, c_s:c_e]
    side = _square_target_size(mask.shape, max_image_resolution)
    resized = cv2.resize(mask, (side, side), interpolation=cv2.INTER_CUBIC)
    return np.float32(resized > 0.5)


class DemoData(Dataset):
    """Single-item dataset wrapping in-memory images and an optional mask."""

    def __init__(self, input_imgs_list, input_mask):
        """Store the raw inputs; all processing happens in :meth:`load`.

        Args:
            input_imgs_list: Sequence whose entries yield the image array at
                index 0 (presumably ``(image, caption)`` pairs; verify
                against the caller).
            input_mask: Array-like mask or ``None`` for "everything is
                foreground".
        """
        self.input_imgs_list = input_imgs_list
        self.input_mask = input_mask

    def __len__(self):
        return 1

    def load(self, input_images_list, mask):
        """Crop, resize and normalize the images; populate instance state.

        Sets ``self.mask_original``, ``self.roi``, ``self.h``, ``self.w``,
        ``self.I`` (``(h, w, 3, num_images)``), ``self.N`` (all-ones
        placeholder) and ``self.mask`` (``(h, w, 1)``).

        The input list is left untouched so that the method can be called
        repeatedly on the same data.

        Returns:
            1 on success.
        """
        if mask is None:
            mask = np.ones_like(input_images_list[0][0])
        else:
            mask = np.array(mask)
        if mask.ndim == 3:
            # Multi-channel mask: only the first channel is used.
            mask = mask[:, :, 0]

        if mask.max() <= 1.0:
            self.mask_original = mask[:, :, None]
        else:
            # Values above 1 are treated as 8-bit and rescaled to [0, 1].
            self.mask_original = mask[:, :, None] / 255.0

        self.roi = get_roi(mask)
        processed = [
            crop_and_resize_img(entry[0], self.roi)
            for entry in input_images_list
        ]
        I = np.stack(processed)
        numberofimages, h, w, _ = I.shape

        mask = crop_and_resize_mask(mask, self.roi)

        # Normalize each image by its brightest masked pixel (channel mean).
        I = np.reshape(I, (-1, h * w, 3))
        intensity = np.mean(I[:, mask.flatten() == 1, :], axis=2)
        mx = np.max(intensity, axis=1)
        I /= (mx.reshape(-1, 1, 1) + 1.0e-6)

        I = np.transpose(I, (1, 2, 0))
        I = I.reshape(h, w, 3, numberofimages)
        mask = (mask.reshape(h, w, 1)).astype(np.float32)

        self.h = h
        self.w = w
        self.I = I
        self.N = np.ones((h, w, 3), np.float32)
        self.mask = mask
        return 1

    def __getitem__(self, idx):
        """Process the stored inputs and return them channel-first."""
        self.load(self.input_imgs_list, self.input_mask)
        return {
            "imgs": self.I.transpose(2, 0, 1, 3),
            "mask": self.mask.transpose(2, 0, 1),
            "mask_original": self.mask_original.transpose(2, 0, 1),
            "roi": self.roi
        }


class TestData(Dataset):
    """Dataset over object directories of DiLiGenT / LUCES / Real style data.

    Every sub-directory of each root is one item. The dataset family is
    selected by substring match on the object path.
    """

    def __init__(
        self,
        data_root: list = None,
        numofimages: int = 16
    ):
        """Collect all object directories below the given roots.

        Args:
            data_root: List of root directories; ``None`` means no roots.
            numofimages: Maximum number of images sampled per object;
                ``None`` uses every image.
        """
        self.data_root = [] if data_root is None else data_root
        self.numberOfImages = numofimages

        self.objlist = []
        for root in self.data_root:
            with os.scandir(root) as entries:
                self.objlist += [
                    entry.path for entry in entries if entry.is_dir()
                ]
        print(f"[Dataset] => {len(self.objlist)} items selected.")
        print(f"Test, => {len(self.objlist)} items selected.")

    def load(self, objlist, dirid):
        """Load, crop and normalize one object; populate instance state.

        Sets ``self.roi``, ``self.h``, ``self.w``, ``self.I``
        (``(h, w, 3, num_images)``), ``self.N`` (``(H, W, 3, 1)``),
        ``self.mask`` (``(h, w, 1)``) and ``self.directlist``.

        Args:
            objlist: List of object directory paths.
            dirid: Index into ``objlist``.

        Returns:
            1 on success, 0 when the dataset type is unknown, no image is
            found, or a file cannot be read.
        """
        obj_path = objlist[dirid]

        if "DiLiGenT" in obj_path:
            nml_path = os.path.join(obj_path, "Normal_gt.png")
            if "10" not in obj_path:  # diligent
                directlist = sorted(glob.glob(os.path.join(obj_path, "0*")))
            else:  # diligent100
                directlist = sorted([
                    path
                    for path in glob.glob(os.path.join(obj_path, "*.png"))
                    if not os.path.basename(path).lower() == "mask.png"
                ])
        elif "LUCES" in obj_path:
            nml_path = os.path.join(obj_path, "normals.png")
            directlist = sorted([
                f
                for i in range(1, 52)
                for f in glob.glob(os.path.join(obj_path, f"{i:02d}*"))
            ])
        elif "Real" in obj_path:
            nml_path = os.path.join(obj_path, "Normal_gt.png")
            directlist = sorted(glob.glob(os.path.join(obj_path, "L*")))
        else:
            print(f"error:unknown dataset{obj_path}")
            return 0

        if not directlist:
            # Nothing to load; signal failure instead of crashing later on.
            print(f"warning: no images found in {obj_path}")
            return 0

        num_images_to_sample = self.numberOfImages
        if (num_images_to_sample is not None
                and num_images_to_sample < len(directlist)):
            indexset = np.random.permutation(
                len(directlist))[:num_images_to_sample]
        else:
            indexset = range(len(directlist))

        read_flags = cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH
        I = None
        mask = None
        n_true = None

        for i, indexofimage in enumerate(indexset):
            img_path = directlist[indexofimage]
            read_img = cv2.imread(img_path, flags=read_flags)
            if read_img is None:
                print(f"warning: can not read {img_path}")
                return 0
            img = cv2.cvtColor(read_img, cv2.COLOR_BGR2RGB)

            if i == 0:
                # Mask, ground-truth normals and ROI are derived only once.
                mask_path = os.path.join(obj_path, "mask.png")
                if os.path.exists(mask_path):
                    raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                    if raw_mask is None:
                        print(f"warning: can not read {mask_path}")
                        return 0
                    mask = raw_mask / 255.0
                else:
                    mask = np.ones_like(read_img)[:, :, 0]

                if os.path.exists(nml_path):
                    raw_nml = cv2.imread(nml_path, flags=read_flags)
                    if raw_nml is None:
                        print(f"warning: can not read {nml_path}")
                        return 0
                    bit_depth = 65535.0 if "LUCES" in obj_path else 255.0
                    N = cv2.cvtColor(raw_nml, cv2.COLOR_BGR2RGB) / bit_depth
                    # Map [0, 1] colors to [-1, 1] and renormalize.
                    N = 2 * N - 1
                    N = N / np.linalg.norm(N, axis=2, keepdims=True)
                    n_true = N * mask[:, :, np.newaxis]

                self.roi = get_roi(mask)
                mask = crop_and_resize_mask(mask, self.roi)

            img = crop_and_resize_img(img, self.roi)
            h, w = img.shape[:2]
            if i == 0:
                I = np.zeros((len(indexset), h, w, 3), np.float32)
            I[i, :, :, :] = img

        # Data normalization: divide every image by a random blend of the
        # mean and the max of its masked per-pixel intensity.
        flat = np.reshape(I, (-1, h * w, 3))
        temp = np.mean(flat[:, mask.flatten() == 1, :], axis=2)
        mean = np.mean(temp, axis=1)
        mx = np.max(temp, axis=1)
        scale = np.random.rand(I.shape[0],)
        temp = (1 - scale) * mean + scale * mx
        I /= (temp.reshape(-1, 1, 1, 1) + 1.0e-6)

        I = np.transpose(I, (1, 2, 3, 0))
        mask = (mask.reshape(h, w, 1)).astype(np.float32)

        self.h = h
        self.w = w
        self.I = I

        is_diligent100 = "DiLiGenT" in obj_path and "10" in obj_path
        if is_diligent100 or "Real" in obj_path or n_true is None:
            # All-ones placeholder when ground truth is not used/available.
            self.N = np.ones((h, w, 3, 1))
        else:
            self.N = n_true[:, :, :, np.newaxis]
        self.mask = mask
        self.directlist = directlist
        return 1

    def __getitem__(self, index_):
        """Return one object; on load failure retry with a random object."""
        objid = index_
        while not self.load(self.objlist, objid):
            objid = np.random.randint(0, len(self.objlist))

        img = self.I.transpose(2, 0, 1, 3)  # 3 h w Nmax
        nml = self.N.transpose(2, 0, 1, 3)  # 3 h w 1
        objname = os.path.basename(self.objlist[objid])
        return {
            'imgs': img,
            'nml': nml,
            "mask": self.mask.transpose(2, 0, 1),
            'directlist': self.directlist,
            'objname': objname,
            'numberOfImages': self.numberOfImages,
            "roi": self.roi
        }

    def __len__(self):
        return len(self.objlist)