lino / src /data /data_module.py
algohunt
initial_commit
c295391
import numpy as np
from torch.utils.data import Dataset
import cv2
import glob
import os
def get_roi(mask, margin=8):
"""
"""
h0, w0 = mask.shape[:2]
if mask is not None:
rows, cols = np.nonzero(mask)
rowmin, rowmax = np.min(rows), np.max(rows)
colmin, colmax = np.min(cols), np.max(cols)
row, col = rowmax - rowmin, colmax - colmin
flag = not (rowmin - margin <= 0 or rowmax + margin > h0 or
colmin - margin <= 0 or colmax + margin > w0)
if row > col and flag:
r_s, r_e = rowmin - margin, rowmax + margin
c_s, c_e = max(colmin - int(0.5 * (row - col)) - margin, 0), \
min(colmax + int(0.5 * (row - col)) + margin, w0)
elif col >= row and flag:
r_s, r_e = max(rowmin - int(0.5 * (col - row)) - margin, 0), \
min(rowmax + int(0.5 * (col - row)) + margin, h0)
c_s, c_e = colmin - margin, colmax + margin
else:
r_s, r_e, c_s, c_e = 0, h0, 0, w0
else:
r_s, r_e, c_s, c_e = 0, h0, 0, w0
return np.array([h0, w0, r_s, r_e, c_s, c_e])
def crop_and_resize_img(img, roi, max_image_resolution=6000):
h0, w0, r_s, r_e, c_s, c_e = roi
img = img[r_s:r_e, c_s:c_e, :]
h = max(512, min(max_image_resolution, (max(img.shape[:2]) // 512) * 512))
w = h
img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
bit_depth = 255.0 if img.dtype == np.uint8 else 65535.0 if img.dtype == np.uint16 else 1.0
img = np.float32(img) / bit_depth
return img
def crop_and_resize_mask(mask, roi, max_image_resolution=6000):
h0, w0, r_s, r_e, c_s, c_e = roi
mask = mask[r_s:r_e, c_s:c_e]
h = max(512, min(max_image_resolution, (max(mask.shape[:2]) // 512) * 512))
w = h
mask = np.float32(cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC) > 0.5)
return mask
class DemoData(Dataset):
def __init__(self,input_imgs_list,input_mask):
self.input_imgs_list = input_imgs_list
self.input_mask = input_mask
def __len__(self):
return 1
def load(self,input_images_list,mask):
if mask is None:
mask = np.ones_like(input_images_list[0][0])
else:
mask = np.array(mask)
mask = mask[:,:,0]
if mask.max() <= 1.0:
self.mask_original = mask[:,:,None]
else:
self.mask_original = mask[:,:,None] / 255.0
self.roi = get_roi(mask)
for i in range(len(input_images_list)):
img = input_images_list[i]
input_images_list[i]= crop_and_resize_img(img[0], self.roi)
I = np.array(input_images_list)
numberofimages,h,w,_ = I.shape
mask = crop_and_resize_mask(mask, self.roi)
I = np.reshape(I, (-1, h * w, 3))
temp = np.mean(I[:, mask.flatten()==1,:], axis=2)
mx = np.max(temp, axis=1)
temp = mx
I /= (temp.reshape(-1,1,1) + 1.0e-6)
I = np.transpose(I, (1, 2, 0))
I = I.reshape(h, w, 3, numberofimages)
mask = (mask.reshape(h, w, 1)).astype(np.float32)
h = mask.shape[0]
w = mask.shape[1]
self.h = h
self.w = w
self.I = I
self.N = np.ones((h, w, 3), np.float32)
self.mask = mask
return 1
def __getitem__(self, idx):
self.load(self.input_imgs_list,self.input_mask)
return {
"imgs":self.I.transpose(2,0,1,3),
"mask":self.mask.transpose(2,0,1),
"mask_original":self.mask_original.transpose(2,0,1),
"roi":self.roi
}
class TestData(Dataset):
def __init__(
self,
data_root: list = None,
numofimages: int = 16
):
self.data_root = data_root
self.numberOfImages = numofimages
self.objlist = []
for i in range(len(self.data_root)):
with os.scandir(self.data_root[i]) as entries:
self.objlist += [entry.path for entry in entries if entry.is_dir()]
print(f"[Dataset] => {len(self.objlist)} items selected.")
objlist = self.objlist
total = len(objlist)
indices = list(range(total))
self.objlist = [objlist[i] for i in indices]
print(f"Test, => {len(self.objlist)} items selected.")
def load(self, objlist, dirid):
obj_path = objlist[dirid]
if "DiLiGenT" in obj_path:
nml_path = os.path.join(obj_path, "Normal_gt.png")
if "10" not in obj_path: # diligent
directlist = sorted(glob.glob(os.path.join(obj_path, f"0*")))
else: # diligent100
directlist = sorted([
path for path in glob.glob(os.path.join(obj_path, "*.png"))
if not os.path.basename(path).lower() == "mask.png"
])
elif "LUCES" in obj_path:
nml_path = os.path.join(obj_path, "normals.png")
directlist = sorted([
f for i in range(1, 52) for f in glob.glob(os.path.join(obj_path, f"{i:02d}*"))
])
elif "Real" in obj_path:
nml_path = os.path.join(obj_path, "Normal_gt.png")
directlist = sorted(glob.glob(os.path.join(obj_path, f"L*")))
else:
print(f"error:unknown dataset{obj_path}")
return 0
num_images_to_sample = self.numberOfImages
if num_images_to_sample is not None and num_images_to_sample < len(directlist):
indexset = np.random.permutation(len(directlist))[:num_images_to_sample]
else:
indexset = range(len(directlist))
I = None
mask = None
N = None
n_true = None
for i, indexofimage in enumerate(indexset):
img_path = directlist[indexofimage]
read_img = cv2.imread(img_path, flags=cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
if read_img is None:
print(f"warning: can not read {img_path}")
return 0
img = cv2.cvtColor(read_img, cv2.COLOR_BGR2RGB)
if i == 0:
mask_path = os.path.join(obj_path, "mask.png")
if os.path.exists(mask_path):
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) / 255.0
else:
mask = np.ones_like(read_img)[:,:,0]
if os.path.exists(nml_path):
bit_depth = 65535.0 if "LUCES" in obj_path else 255.0
N = cv2.cvtColor(cv2.imread(nml_path, flags=cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH), cv2.COLOR_BGR2RGB) / bit_depth
N = 2 * N - 1
N = N / np.linalg.norm(N, axis=2, keepdims=True)
N = N * mask[:, :, np.newaxis]
n_true = N
self.roi = get_roi(mask)
mask = crop_and_resize_mask(mask, self.roi)
img= crop_and_resize_img(img, self.roi)
h, w = img.shape[:2]
if i == 0:
I = np.zeros((len(indexset), h, w, 3), np.float32)
I[i, :, :, :] = img
imgs_ = I.copy()
I = np.reshape(I, (-1, h * w, 3))
"""Data Normalization"""
temp = np.mean(I[:, mask.flatten()==1,:], axis=2)
mean = np.mean(temp, axis=1)
mx = np.max(temp, axis=1)
scale = np.random.rand(I.shape[0],)
temp = (1-scale) * mean + scale * mx
imgs_ /= (temp.reshape(-1,1,1,1) + 1.0e-6)
I = imgs_
I = np.transpose(I, (1, 2, 3, 0))
mask = (mask.reshape(h, w, 1)).astype(np.float32)
h = mask.shape[0]
w = mask.shape[1]
self.h = h
self.w = w
self.I = I #
if ("DiLiGenT" in obj_path and "10" in obj_path) or "Real" in obj_path: # diligent100
self.N = np.ones((h,w,3,1))
else:
self.N = n_true[:,:,:,np.newaxis]
self.mask = mask
self.directlist = directlist
return 1
def __getitem__(self, index_):
objid = index_
while 1:
success = self.load(self.objlist, objid)
if success:
break
else:
objid = np.random.randint(0, len(self.objlist))
img = self.I.transpose(2,0,1,3) # 3 h w Nmax
nml = self.N.transpose(2,0,1,3) # 3 h w 1
objname = os.path.basename(os.path.basename(self.objlist[objid]))
numberOfImages = self.numberOfImages
try:
output = {
'imgs': img,
'nml': nml,
"mask":self.mask.transpose(2,0,1),
'directlist': self.directlist,
'objname': objname,
'numberOfImages': numberOfImages,
"roi":self.roi
}
return output
except:
raise KeyError
def __len__(self):
return len(self.objlist)