import glob
import os

import cv2
import numpy as np
from torch.utils.data import Dataset
def get_roi(mask, margin=8):
    """Compute a roughly square region of interest around the nonzero mask pixels.

    Returns np.array([h0, w0, r_s, r_e, c_s, c_e]): the original image size and
    the row/column bounds of the margin-padded crop.
    """
    h0, w0 = mask.shape[:2]
    if mask is not None and np.any(mask):
        rows, cols = np.nonzero(mask)
        rowmin, rowmax = np.min(rows), np.max(rows)
        colmin, colmax = np.min(cols), np.max(cols)
        row, col = rowmax - rowmin, colmax - colmin
        # Only crop when the margin-padded box still fits inside the image.
        flag = not (rowmin - margin <= 0 or rowmax + margin > h0 or
                    colmin - margin <= 0 or colmax + margin > w0)
        if row > col and flag:
            r_s, r_e = rowmin - margin, rowmax + margin
            c_s, c_e = max(colmin - int(0.5 * (row - col)) - margin, 0), \
                       min(colmax + int(0.5 * (row - col)) + margin, w0)
        elif col >= row and flag:
            r_s, r_e = max(rowmin - int(0.5 * (col - row)) - margin, 0), \
                       min(rowmax + int(0.5 * (col - row)) + margin, h0)
            c_s, c_e = colmin - margin, colmax + margin
        else:
            r_s, r_e, c_s, c_e = 0, h0, 0, w0
    else:
        # Empty mask: fall back to the full image.
        r_s, r_e, c_s, c_e = 0, h0, 0, w0
    return np.array([h0, w0, r_s, r_e, c_s, c_e])
def crop_and_resize_img(img, roi, max_image_resolution=6000):
    """Crop an image to the ROI and resize it to a square whose side is a multiple of 512."""
    h0, w0, r_s, r_e, c_s, c_e = roi
    img = img[r_s:r_e, c_s:c_e, :]
    # Target side length: nearest lower multiple of 512, clamped to [512, max_image_resolution].
    h = max(512, min(max_image_resolution, (max(img.shape[:2]) // 512) * 512))
    w = h
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
    # Normalize to [0, 1] according to the source bit depth.
    bit_depth = 255.0 if img.dtype == np.uint8 else 65535.0 if img.dtype == np.uint16 else 1.0
    img = np.float32(img) / bit_depth
    return img
def crop_and_resize_mask(mask, roi, max_image_resolution=6000):
    """Crop a mask to the ROI, resize it like the images, and re-binarize with a 0.5 threshold."""
    h0, w0, r_s, r_e, c_s, c_e = roi
    mask = mask[r_s:r_e, c_s:c_e]
    h = max(512, min(max_image_resolution, (max(mask.shape[:2]) // 512) * 512))
    w = h
    mask = np.float32(cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC) > 0.5)
    return mask
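

# Minimal usage sketch for the helpers above (illustrative only; the file names are
# hypothetical). An image and its mask are cropped with the same ROI so they stay aligned:
#
#   mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE) / 255.0               # hypothetical path
#   img = cv2.cvtColor(cv2.imread("view_00.png"), cv2.COLOR_BGR2RGB)          # hypothetical path
#   roi = get_roi(mask, margin=8)
#   img_sq = crop_and_resize_img(img, roi)    # float32, HxWx3 in [0, 1], H == W, multiple of 512
#   mask_sq = crop_and_resize_mask(mask, roi) # float32, HxW in {0, 1}, same H and W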
class DemoData(Dataset):
    """Wraps a list of in-memory images (and an optional mask) as a single-item dataset."""

    def __init__(self, input_imgs_list, input_mask):
        self.input_imgs_list = input_imgs_list
        self.input_mask = input_mask

    def __len__(self):
        return 1

    def load(self, input_images_list, mask):
        if mask is None:
            mask = np.ones_like(input_images_list[0][0])
        else:
            mask = np.array(mask)
        mask = mask[:, :, 0]
        if mask.max() <= 1.0:
            self.mask_original = mask[:, :, None]
        else:
            self.mask_original = mask[:, :, None] / 255.0
        self.roi = get_roi(mask)
        # Crop and resize every input image with the same ROI (without mutating the caller's list).
        imgs = []
        for i in range(len(input_images_list)):
            img = input_images_list[i]
            imgs.append(crop_and_resize_img(img[0], self.roi))
        I = np.array(imgs)
        numberofimages, h, w, _ = I.shape
        mask = crop_and_resize_mask(mask, self.roi)
        # Normalize each image by the maximum of its per-pixel channel mean inside the mask.
        I = np.reshape(I, (-1, h * w, 3))
        temp = np.mean(I[:, mask.flatten() == 1, :], axis=2)
        temp = np.max(temp, axis=1)
        I /= (temp.reshape(-1, 1, 1) + 1.0e-6)
        I = np.transpose(I, (1, 2, 0))
        I = I.reshape(h, w, 3, numberofimages)
        mask = (mask.reshape(h, w, 1)).astype(np.float32)
        self.h, self.w = mask.shape[0], mask.shape[1]
        self.I = I
        self.N = np.ones((h, w, 3), np.float32)
        self.mask = mask
        return 1

    def __getitem__(self, idx):
        self.load(self.input_imgs_list, self.input_mask)
        return {
            "imgs": self.I.transpose(2, 0, 1, 3),
            "mask": self.mask.transpose(2, 0, 1),
            "mask_original": self.mask_original.transpose(2, 0, 1),
            "roi": self.roi,
        }
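

# Minimal usage sketch for DemoData (illustrative; the variable names and the exact
# structure of the inputs are assumptions). Each entry of input_imgs_list is indexed
# as img[0] in load(), so it is expected to be a tuple/list whose first element is an
# HxWx3 image array:
#
#   from torch.utils.data import DataLoader
#   dataset = DemoData(input_imgs_list=[(img0,), (img1,), (img2,)], input_mask=mask_img)
#   loader = DataLoader(dataset, batch_size=1, shuffle=False)
#   batch = next(iter(loader))
#   batch["imgs"].shape  # (1, 3, H, W, num_images)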
class TestData(Dataset):
    """Loads photometric-stereo test objects (DiLiGenT, DiLiGenT10^2, LUCES, Real) from disk."""

    def __init__(
        self,
        data_root: list = None,
        numofimages: int = 16,
    ):
        self.data_root = data_root
        self.numberOfImages = numofimages
        # Every sub-directory of each data root is treated as one test object.
        self.objlist = []
        for i in range(len(self.data_root)):
            with os.scandir(self.data_root[i]) as entries:
                self.objlist += [entry.path for entry in entries if entry.is_dir()]
        print(f"[Dataset] => {len(self.objlist)} items selected.")
    def load(self, objlist, dirid):
        obj_path = objlist[dirid]
        # Pick the file-naming convention of the benchmark from the object path.
        if "DiLiGenT" in obj_path:
            nml_path = os.path.join(obj_path, "Normal_gt.png")
            if "10" not in obj_path:  # DiLiGenT
                directlist = sorted(glob.glob(os.path.join(obj_path, "0*")))
            else:  # DiLiGenT10^2
                directlist = sorted([
                    path for path in glob.glob(os.path.join(obj_path, "*.png"))
                    if not os.path.basename(path).lower() == "mask.png"
                ])
        elif "LUCES" in obj_path:
            nml_path = os.path.join(obj_path, "normals.png")
            directlist = sorted([
                f for i in range(1, 52) for f in glob.glob(os.path.join(obj_path, f"{i:02d}*"))
            ])
        elif "Real" in obj_path:
            nml_path = os.path.join(obj_path, "Normal_gt.png")
            directlist = sorted(glob.glob(os.path.join(obj_path, "L*")))
        else:
            print(f"error: unknown dataset {obj_path}")
            return 0

        # Randomly subsample the light directions when more images are available than requested.
        num_images_to_sample = self.numberOfImages
        if num_images_to_sample is not None and num_images_to_sample < len(directlist):
            indexset = np.random.permutation(len(directlist))[:num_images_to_sample]
        else:
            indexset = range(len(directlist))

        I = None
        mask = None
        N = None
        n_true = None
        for i, indexofimage in enumerate(indexset):
            img_path = directlist[indexofimage]
            read_img = cv2.imread(img_path, flags=cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
            if read_img is None:
                print(f"warning: cannot read {img_path}")
                return 0
            img = cv2.cvtColor(read_img, cv2.COLOR_BGR2RGB)
            if i == 0:
                # Load the object mask (all ones when no mask file is provided).
                mask_path = os.path.join(obj_path, "mask.png")
                if os.path.exists(mask_path):
                    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) / 255.0
                else:
                    mask = np.ones_like(read_img)[:, :, 0]
                # Load the ground-truth normal map when it exists.
                if os.path.exists(nml_path):
                    bit_depth = 65535.0 if "LUCES" in obj_path else 255.0
                    N = cv2.cvtColor(cv2.imread(nml_path, flags=cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH), cv2.COLOR_BGR2RGB) / bit_depth
                    N = 2 * N - 1
                    N = N / np.linalg.norm(N, axis=2, keepdims=True)
                    N = N * mask[:, :, np.newaxis]
                    n_true = N
                self.roi = get_roi(mask)
                mask = crop_and_resize_mask(mask, self.roi)
            img = crop_and_resize_img(img, self.roi)
            h, w = img.shape[:2]
            if i == 0:
                I = np.zeros((len(indexset), h, w, 3), np.float32)
            I[i, :, :, :] = img

        imgs_ = I.copy()
        I = np.reshape(I, (-1, h * w, 3))
        # Data normalization: scale each image by a random blend of the mean and the
        # maximum of its per-pixel channel mean inside the mask.
        temp = np.mean(I[:, mask.flatten() == 1, :], axis=2)
        mean = np.mean(temp, axis=1)
        mx = np.max(temp, axis=1)
        scale = np.random.rand(I.shape[0],)
        temp = (1 - scale) * mean + scale * mx
        imgs_ /= (temp.reshape(-1, 1, 1, 1) + 1.0e-6)
        I = imgs_
        I = np.transpose(I, (1, 2, 3, 0))
        mask = (mask.reshape(h, w, 1)).astype(np.float32)
        self.h, self.w = mask.shape[0], mask.shape[1]
        self.I = I
        if ("DiLiGenT" in obj_path and "10" in obj_path) or "Real" in obj_path:  # DiLiGenT10^2 / Real: no ground-truth normals
            self.N = np.ones((h, w, 3, 1))
        else:
            self.N = n_true[:, :, :, np.newaxis]
        self.mask = mask
        self.directlist = directlist
        return 1
    def __getitem__(self, index_):
        objid = index_
        # Retry with a random object if loading fails (unknown dataset or unreadable image).
        while True:
            success = self.load(self.objlist, objid)
            if success:
                break
            objid = np.random.randint(0, len(self.objlist))
        img = self.I.transpose(2, 0, 1, 3)  # 3 x h x w x numberOfImages
        nml = self.N.transpose(2, 0, 1, 3)  # 3 x h x w x 1
        objname = os.path.basename(self.objlist[objid])
        return {
            "imgs": img,
            "nml": nml,
            "mask": self.mask.transpose(2, 0, 1),
            "directlist": self.directlist,
            "objname": objname,
            "numberOfImages": self.numberOfImages,
            "roi": self.roi,
        }

    def __len__(self):
        return len(self.objlist)
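

# Hedged usage sketch for TestData: the data_root below is a hypothetical placeholder.
# It must contain one sub-directory per object, and "DiLiGenT", "LUCES", or "Real" must
# appear somewhere in the path so load() can pick the right file-naming scheme.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    test_set = TestData(data_root=["/path/to/DiLiGenT/pmsData"], numofimages=16)  # hypothetical path
    loader = DataLoader(test_set, batch_size=1, shuffle=False)
    sample = next(iter(loader))
    # imgs: (1, 3, H, W, numberOfImages), nml: (1, 3, H, W, 1), mask: (1, 1, H, W)
    print(sample["objname"], sample["imgs"].shape, sample["nml"].shape)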