import torchvision import torch from torchmetrics import MeanMetric import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from .module.utils import * from .utils import decompose_tensors from .utils import gauss_filter import cv2 import pytorch_lightning as pl from src.models.utils.compute_mae import compute_mae_np from datetime import datetime class LiNo_UniPS(pl.LightningModule): def __init__(self, pixel_samples: int = 2048, task_name :str = None): super().__init__() self.pixel_samples = pixel_samples self.task_name = task_name self.input_dim = 4 self.image_encoder = ScaleInvariantSpatialLightImageEncoder(self.input_dim, use_efficient_attention=False) self.input_dim = 0 self.glc_upsample = GLC_Upsample(256+self.input_dim, num_enc_sab=1, dim_hidden=256, dim_feedforward=1024, use_efficient_attention=True) self.glc_aggregation = GLC_Aggregation(256+self.input_dim, num_agg_transformer=2, dim_aggout=384, dim_feedforward=1024, use_efficient_attention=False) self.img_embedding = nn.Sequential( nn.Linear(3,32), nn.LeakyReLU(), nn.Linear(32, 256) ) self.regressor = Regressor(384, num_enc_sab=1, use_efficient_attention=True, dim_feedforward=1024) self.test_mae = MeanMetric() self.test_loss = MeanMetric() def on_test_start(self): timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") self.run_save_dir = f'output/{timestamp}/{self.task_name}/results/' os.makedirs(self.run_save_dir, exist_ok=True) def from_pretrained(self,pth_path): pretrain_weight = torch.load(pth_path,weights_only=False) self.load_state_dict(pretrain_weight, strict=False) def _prepare_test_inputs(self, batch): img = batch["imgs"].to(torch.bfloat16) self.numberofImages = img.shape[-1] print("number of test images", self.numberofImages) nml = batch["nml"].to(torch.bfloat16) directlist = batch["directlist"] roi = batch.get("roi",None) roi = roi[0].cpu().numpy() return img, nml,directlist,roi def _postprocess_prediction(self, nml_predict_raw, nml_gt_raw, roi): h_orig, w_orig, r_s, r_e, c_s, c_e = roi nml_predict = nml_predict_raw.squeeze().permute(1, 2, 0).cpu().numpy() nml_predict = cv2.resize(nml_predict, dsize=(c_e - c_s, r_e - r_s), interpolation=cv2.INTER_AREA) mask = np.float32(np.abs(1 - np.sqrt(np.sum(nml_predict * nml_predict, axis=2))) < 0.5) nml_predict = np.divide(nml_predict, np.linalg.norm(nml_predict, axis=2, keepdims=True) + 1e-12) nml_predict = nml_predict * mask[:, :, np.newaxis] nout = np.zeros((h_orig, w_orig, 3), np.float32) nout[r_s:r_e, c_s:c_e, :] = nml_predict nml_gt = nml_gt_raw.squeeze().permute(1, 2, 0).float().cpu().numpy() mask_gt = np.float32(np.abs(1 - np.sqrt(np.sum(nml_gt * nml_gt, axis=2))) < 0.5) return nout, nml_gt, mask_gt def _calculate_and_log_metrics(self, nout, nml_gt, mask_gt): mse = torch.nn.MSELoss()(torch.tensor(nout).to(self.device), torch.tensor(nml_gt).to(self.device)) mae, emap = compute_mae_np(nout, nml_gt, mask_gt) self.test_loss(mse) self.test_mae(mae) self.log("test/mse", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("test/mae", self.test_mae, on_step=False, on_epoch=True, prog_bar=True) return mse, mae, emap def _save_test_results(self, nout, nml_gt, emap, img, loss, mae, directlist, save_dir): obj_name_parts = os.path.dirname(directlist[0][0]).split('/') obj_name = obj_name_parts[-1] save_path = os.path.join(save_dir,f'{self.numberofImages}',f'{obj_name}') os.makedirs(save_path, exist_ok=True) print(f"save to: {save_path}") if ("DiLiGenT_100" not in self.task_name) and ("Real" not in self.task_name): nout_to_save = (nout + 1) / 2 nml_gt_to_save = (nml_gt + 1) / 2 emap_to_save = emap.astype(np.float32).squeeze() thresh = 90 emap_to_save[emap_to_save >= thresh] = thresh emap_to_save = emap_to_save / thresh plt.imsave(save_path + '/nml_predict.png', np.clip(nout_to_save, 0, 1)) plt.imsave(save_path + '/nml_gt.png', np.clip(nml_gt_to_save, 0, 1)) plt.imsave(save_path + '/error_map.png', emap_to_save, cmap='jet') torchvision.utils.save_image(img.squeeze(0).permute(3,0,1,2), save_path + '/tiled.png') fig, axes = plt.subplots(1, 3, figsize=(12, 4)) axes[0].imshow(np.clip(nout_to_save, 0, 1)) axes[0].set_title('Prediction'); axes[0].axis('off') axes[1].imshow(np.clip(nml_gt_to_save, 0, 1)) axes[1].set_title('Ground Truth'); axes[1].axis('off') axes[2].imshow(emap, cmap='jet') axes[2].set_title('Error Map'); axes[2].axis('off') plt.figtext(0.5, 0.02, f'Loss: {loss:.4f} | MAE: {mae:.4f}', ha='center', fontsize=12) plt.tight_layout() plt.savefig(save_path + '/combined.png', dpi=300) plt.close(fig) with open(save_path + '/result.txt', 'w') as f: f.write(f"loss: {loss.item()}\n") f.write(f"mae: {mae}\n") print(f"Done for {obj_name}") else: if "DiLiGenT_100" in self.task_name: from scipy.io import savemat mat_save_path = os.path.join(os.path.dirname(save_path),"submit") os.makedirs(mat_save_path,exist_ok=True) normal_map = nout savemat(mat_save_path + "/" + obj_name + '.mat', {'Normal_est': normal_map}) torchvision.utils.save_image(img.squeeze(0).permute(3,0,1,2), save_path + '/tiled.png') nout = (nout + 1) / 2 plt.imsave(save_path + '/nml_predict.png', nout) def test_step(self, batch, batch_idx): img, nml_gt, directlist, roi = self._prepare_test_inputs(batch) nml_predict = self.model_step(batch) nout, nml_gt, mask_gt = self._postprocess_prediction(nml_predict, nml_gt, roi) if ("DiLiGenT_100" not in self.task_name) and ("Real" not in self.task_name): loss, mae, emap = self._calculate_and_log_metrics(nout, nml_gt, mask_gt) print(f"{os.path.basename(os.path.dirname(directlist[0][0]))} | MAE: {mae:.4f}") self._save_test_results(nout, nml_gt, emap, img, loss, mae, directlist, self.run_save_dir) else: emap,loss,mae = None,None,None self._save_test_results(nout, nml_gt, emap, img, loss, mae, directlist, self.run_save_dir) def predict_step(self,batch): roi = batch.get("roi",None) nml_predict = self.model_step(batch) roi = roi[0].cpu().numpy() h_ = roi[0] w_ = roi[1] r_s = roi[2] r_e = roi[3] c_s = roi[4] c_e = roi[5] nml_predict = nml_predict.squeeze().permute(1,2,0).cpu().numpy() nml_predict = cv2.resize(nml_predict, dsize=(c_e-c_s, r_e-r_s), interpolation=cv2.INTER_AREA) nml_predict = np.divide(nml_predict, np.linalg.norm(nml_predict, axis=2, keepdims=True) + 1.0e-12) mask = np.float32(np.abs(1 - np.sqrt(np.sum(nml_predict * nml_predict, axis=2))) < 0.5) nml_predict = nml_predict * mask[:, :, np.newaxis] nout = np.zeros((h_, w_, 3), np.float32) nout[r_s:r_e, c_s:c_e,:] = nml_predict mask = batch["mask_original"].squeeze().cpu().numpy()[:,:,None] return nout*mask def model_step(self,batch): I = batch.get("imgs",None) M = batch.get("mask",None) # roi = batch.get("roi",None) B, C, H, W, Nmax = I.shape patch_size = 512 patches_I = decompose_tensors.divide_tensor_spatial(I.permute(0,4,1,2,3).reshape(-1, C, H, W), block_size=patch_size, method='tile_stride') patches_I = patches_I.reshape(B, Nmax, -1, C, patch_size, patch_size).permute(0, 2, 3, 4, 5, 1) sliding_blocks = patches_I.shape[1] patches_M = decompose_tensors.divide_tensor_spatial(M, block_size=patch_size, method='tile_stride') patches_nml = [] nImgArray = np.array([Nmax]) canonical_resolution = 256 for k in range(sliding_blocks): """ Image Encoder at Canonical Resolution """ print("please wait for a moment, it may take a while") I = patches_I[:, k, :, :, :, :] M = patches_M[:, k, :, :, :] B, C, H, W, Nmax = I.shape decoder_resolution = H I_enc = I.permute(0, 4, 1, 2, 3) M_enc = M img_index = make_index_list(Nmax, nImgArray) I_enc = I_enc.reshape(-1, I_enc.shape[2], I_enc.shape[3], I_enc.shape[4]) M_enc = M_enc.unsqueeze(1).expand(-1, Nmax, -1, -1, -1).reshape(-1, 1, H, W) data = I_enc * M_enc data = data[img_index==1,:,:,:] glc,_= self.image_encoder(data, nImgArray, canonical_resolution) I_dec = [] M_dec = [] img = I.permute(0, 4, 1, 2, 3) """ Sample Decoder at Original Resokution""" img = img.squeeze() I_dec = F.interpolate(img.float(), size=(decoder_resolution, decoder_resolution), mode='bilinear', align_corners=False).to(torch.bfloat16) M_dec = F.interpolate(M.float(), size=(decoder_resolution, decoder_resolution), mode='nearest').to(torch.bfloat16) decoder_imgsize = (decoder_resolution, decoder_resolution) C = img.shape[1] H = decoder_imgsize[0] W = decoder_imgsize[1] nout = torch.zeros(B, H * W, 3).to(I.device) f_scale = decoder_resolution//canonical_resolution smoothing = gauss_filter.gauss_filter(glc.shape[1], 10 * f_scale+1, 1).to(glc.device) chunk_size = 16 processed_chunks = [] for glc_chunk in torch.split(glc, chunk_size, dim=0): smoothed_chunk = smoothing(glc_chunk) processed_chunks.append(smoothed_chunk) glc = torch.cat(processed_chunks, dim=0) del M _, _, H, W = I_dec.shape p = 0 nout = torch.zeros(B, H * W, 3).to(I.device) conf_out = torch.zeros(B, H * W, 1).to(I.device) for b in range(B): target = range(p, p+nImgArray[b]) p = p+nImgArray[b] m_ = M_dec[b, :, :, :].reshape(-1, H * W).permute(1,0) ids = np.nonzero(m_>0)[:,0] ids = ids[np.random.permutation(len(ids))] ids_shuffle = ids[np.random.permutation(len(ids))] num_split = len(ids) // self.pixel_samples + 1 idset = np.array_split(ids_shuffle, num_split) o_ = I_dec[target, :, :, :].reshape(nImgArray[b], C, H * W).permute(2,0,1) for ids in idset: o_ids = o_[ids, :, :] glc_ids = glc[target, :, :, :].permute(2,3,0,1).flatten(0,1)[ids,:,:] o_ids = self.img_embedding(o_ids) x = o_ids + glc_ids glc_ids = self.glc_upsample(x) x = o_ids + glc_ids x = self.glc_aggregation(x) x_n, _, _, conf = self.regressor(x, len(ids)) x_n = F.normalize(x_n, p=2, dim=-1) nout[b, ids, :] = x_n[b,:,:] conf_out[b, ids, :] = conf[b,:,:].to(torch.float32) nout = nout.reshape(B,H,W,3).permute(0,3,1,2) conf_out = conf_out.reshape(B,H,W,1).permute(0,3,1,2) patches_nml.append(nout) patches_nml = torch.stack(patches_nml, dim=1) merged_tensor_nml = decompose_tensors.merge_tensor_spatial(patches_nml.permute(1,0,2,3,4), method='tile_stride') return merged_tensor_nml