Upload 15 files
Browse files- .gitattributes +1 -0
- LUT-Fuse-main/LICENSE +21 -0
- LUT-Fuse-main/README.md +84 -0
- LUT-Fuse-main/assets/framework.png +3 -0
- LUT-Fuse-main/ckpts/fine_tuned_lut.npy +3 -0
- LUT-Fuse-main/ckpts/fine_tuned_lut_original.npy +3 -0
- LUT-Fuse-main/ckpts/generator_context.pth +3 -0
- LUT-Fuse-main/ckpts/generator_context_original.pth +3 -0
- LUT-Fuse-main/data/o_fusion_dataset.py +64 -0
- LUT-Fuse-main/data/simple_dataset.py +218 -0
- LUT-Fuse-main/fine_tune_lut.py +264 -0
- LUT-Fuse-main/requirements.txt +104 -0
- LUT-Fuse-main/scripts/calculate.py +180 -0
- LUT-Fuse-main/scripts/loss_lut.py +196 -0
- LUT-Fuse-main/test_lut.py +93 -0
- LUT-Fuse-main/transforms.py +107 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
LUT-Fuse-main/assets/framework.png filter=lfs diff=lfs merge=lfs -text
|
LUT-Fuse-main/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Yibing Zhang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
LUT-Fuse-main/README.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center">[ICCV 2025] LUT-Fuse</h1>
|
| 2 |
+
<p align="center">
|
| 3 |
+
<em>Towards Extremely Fast Infrared and Visible Image Fusion via Distillation to Learnable Look-Up Tables</em>
|
| 4 |
+
</p>
|
| 5 |
+
|
| 6 |
+
<p align="center">
|
| 7 |
+
<a href="https://github.com/zyb5/LUT-Fuse" style="text-decoration:none;">
|
| 8 |
+
<img src="https://img.shields.io/badge/GitHub-Code-black?logo=github" alt="Code" />
|
| 9 |
+
</a>
|
| 10 |
+
<a href="https://arxiv.org/abs/2509.00346" style="text-decoration:none; margin-left:8px;">
|
| 11 |
+
<img src="https://img.shields.io/badge/arXiv-Paper-B31B1B?logo=arxiv" alt="Paper" />
|
| 12 |
+
</a>
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
<p align="center">
|
| 16 |
+
<img src="assets/framework.png" alt="LUT-Fuse Framework" width="90%">
|
| 17 |
+
</p>
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## ⚙️ Environment
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
conda create -n lutfuse python=3.8
|
| 25 |
+
conda activate lutfuse
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
```
|
| 29 |
+
conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia
|
| 30 |
+
pip install -r requirements.txt
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## 📂 Dataset
|
| 34 |
+
|
| 35 |
+
You should list your dataset as followed rule:
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
|dataset
|
| 39 |
+
|train
|
| 40 |
+
|Infrared
|
| 41 |
+
|Visible
|
| 42 |
+
|Fuse_ref
|
| 43 |
+
|test
|
| 44 |
+
|Infrared
|
| 45 |
+
|Visible
|
| 46 |
+
|Fuse_ref
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## 💾 Checkpoints
|
| 50 |
+
|
| 51 |
+
We provide our **pretrained checkpoints** directly in this repository for convenience.
|
| 52 |
+
You can find them under [`./ckpts`](./ckpts).
|
| 53 |
+
|
| 54 |
+
- **Fusion LUT weights:** `ckpts/fine_tuned_lut.npy`
|
| 55 |
+
- **Context generator weights:** `ckpts/generator_context.pth`
|
| 56 |
+
|
| 57 |
+
## 🧪 Test
|
| 58 |
+
|
| 59 |
+
```
|
| 60 |
+
CUDA_VISIBLE_DEVICES=0 python test_lut.py
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## 🚀 Train
|
| 64 |
+
|
| 65 |
+
```
|
| 66 |
+
CUDA_VISIBLE_DEVICES=0 python fine_tune_lut.py
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## 📖 Citation
|
| 70 |
+
|
| 71 |
+
If you find our work or dataset useful for your research, please cite our paper.
|
| 72 |
+
|
| 73 |
+
```bibtex
|
| 74 |
+
@inproceedings{yi2025LUT-Fuse,
|
| 75 |
+
title={LUT-Fuse: Towards Extremely Fast Infrared and Visible Image Fusion via Distillation to Learnable Look-Up Tables},
|
| 76 |
+
author={Yi, Xunpeng and Zhang, Yibing and Xiang, Xinyu and Yan, Qinglong and Xu, Han and Ma, Jiayi},
|
| 77 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
| 78 |
+
year={2025}
|
| 79 |
+
}
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
If you have any questions, please send an email to zhangyibing@whu.edu.cn
|
| 83 |
+
|
| 84 |
+
|
LUT-Fuse-main/assets/framework.png
ADDED
|
Git LFS Details
|
LUT-Fuse-main/ckpts/fine_tuned_lut.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24f70326cb33b27285e5157594d31a79ee94b67b043828cb98a7d60aeec920e4
|
| 3 |
+
size 262272
|
LUT-Fuse-main/ckpts/fine_tuned_lut_original.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ffa74433659b3a0d8e39ee1f2f4fda2d776ad2967a4e1d9190a67e314099a43
|
| 3 |
+
size 262272
|
LUT-Fuse-main/ckpts/generator_context.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fcacc0280f6ee37d42ae7b82166484f9fe1660526c647c560e051ff23885324
|
| 3 |
+
size 38143
|
LUT-Fuse-main/ckpts/generator_context_original.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7a4ec41712435defb95d3c9cfc23839cf79bc00ae22ba47cc5755d480525780
|
| 3 |
+
size 37559
|
LUT-Fuse-main/data/o_fusion_dataset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
import os
|
| 5 |
+
from glob import glob
|
| 6 |
+
from torchvision.transforms import RandomCrop
|
| 7 |
+
import torchvision.transforms.functional as F
|
| 8 |
+
|
| 9 |
+
class RandomCropPair:
|
| 10 |
+
def __init__(self, size):
|
| 11 |
+
self.size = size # 裁剪尺寸 (h, w)
|
| 12 |
+
|
| 13 |
+
def __call__(self, vis_img, ir_img, fuse_image):
|
| 14 |
+
# 获取随机裁剪参数
|
| 15 |
+
i, j, h, w = RandomCrop.get_params(vis_img, output_size=self.size)
|
| 16 |
+
# 对可见光和红外图像使用相同的裁剪参数
|
| 17 |
+
vis_img = F.crop(vis_img, i, j, h, w)
|
| 18 |
+
ir_img = F.crop(ir_img, i, j, h, w)
|
| 19 |
+
fuse_image = F.crop(fuse_image, i, j, h, w)
|
| 20 |
+
|
| 21 |
+
vis_img = F.to_tensor(vis_img)
|
| 22 |
+
ir_img = F.to_tensor(ir_img)
|
| 23 |
+
fuse_image = F.to_tensor(fuse_image)
|
| 24 |
+
return vis_img, ir_img, fuse_image
|
| 25 |
+
|
| 26 |
+
class DistillDataSet(Dataset):
|
| 27 |
+
def __init__(self, visible_path, infrared_path, other_fuse_path, phase="train", transform=None):
|
| 28 |
+
self.phase = phase
|
| 29 |
+
self.visible_files = sorted(glob(os.path.join(visible_path, "*.*")))
|
| 30 |
+
self.infrared_files = sorted(glob(os.path.join(infrared_path, "*.*")))
|
| 31 |
+
self.other_fuse_files = sorted(glob(os.path.join(other_fuse_path, "*.*")))
|
| 32 |
+
self.transform = transform
|
| 33 |
+
|
| 34 |
+
def __len__(self):
|
| 35 |
+
l = len(self.infrared_files)
|
| 36 |
+
return l
|
| 37 |
+
|
| 38 |
+
def __getitem__(self, item):
|
| 39 |
+
image_A_path = self.visible_files[item]
|
| 40 |
+
image_B_path = self.infrared_files[item]
|
| 41 |
+
other_fuse_path = self.other_fuse_files[item]
|
| 42 |
+
image_A = Image.open(image_A_path).convert(mode='RGB')
|
| 43 |
+
image_B = Image.open(image_B_path).convert(mode='L') ##########
|
| 44 |
+
other_fuse = Image.open(other_fuse_path).convert(mode='RGB')
|
| 45 |
+
|
| 46 |
+
if self.transform is not None:
|
| 47 |
+
if isinstance(self.transform, RandomCropPair):
|
| 48 |
+
image_A, image_B, other_fuse = self.transform(image_A, image_B, other_fuse)
|
| 49 |
+
else:
|
| 50 |
+
image_A = self.transform(image_A)
|
| 51 |
+
image_B = self.transform(image_B)
|
| 52 |
+
other_fuse = self.transform(other_fuse)
|
| 53 |
+
|
| 54 |
+
name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
|
| 55 |
+
|
| 56 |
+
return image_A, image_B, other_fuse, name
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def collate_fn(batch):
|
| 60 |
+
images_A, images_B, other_fuse, name = zip(*batch)
|
| 61 |
+
images_A = torch.stack(images_A, dim=0)
|
| 62 |
+
images_B = torch.stack(images_B, dim=0)
|
| 63 |
+
other_fuse = torch.stack(other_fuse, dim=0)
|
| 64 |
+
return images_A, images_B, other_fuse, name
|
LUT-Fuse-main/data/simple_dataset.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
import os
|
| 5 |
+
from glob import glob
|
| 6 |
+
import transforms as T
|
| 7 |
+
from torchvision.transforms import RandomCrop
|
| 8 |
+
import torchvision.transforms.functional as F
|
| 9 |
+
|
| 10 |
+
# class SimpleDataSet(Dataset):
|
| 11 |
+
# def __init__(self, visible_path, visible_gt_path, infrared_path, phase="train", transform=None):
|
| 12 |
+
# self.phase = phase
|
| 13 |
+
# self.visible_files = sorted(glob(os.path.join(visible_path, "*")))
|
| 14 |
+
# self.visible_gt_files = sorted(glob(os.path.join(visible_gt_path, "*")))
|
| 15 |
+
# self.infrared_files = sorted(glob(os.path.join(infrared_path, "*")))
|
| 16 |
+
# self.transform = transform
|
| 17 |
+
#
|
| 18 |
+
# def __len__(self):
|
| 19 |
+
# l = len(self.infrared_files)
|
| 20 |
+
# return l
|
| 21 |
+
#
|
| 22 |
+
# def __getitem__(self, item):
|
| 23 |
+
# image_A_path = self.visible_files[item]
|
| 24 |
+
# image_A_gt_path = self.visible_gt_files[item]
|
| 25 |
+
# image_B_path = self.infrared_files[item]
|
| 26 |
+
# image_A = Image.open(image_A_path).convert(mode='RGB')
|
| 27 |
+
# image_A_gt = Image.open(image_A_gt_path).convert(mode='RGB')
|
| 28 |
+
# image_B = Image.open(image_B_path).convert(mode='L') ##########
|
| 29 |
+
#
|
| 30 |
+
# image_A = self.transform(image_A)
|
| 31 |
+
# image_A_gt = self.transform(image_A_gt)
|
| 32 |
+
# image_B = self.transform(image_B)
|
| 33 |
+
#
|
| 34 |
+
# name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
|
| 35 |
+
#
|
| 36 |
+
# return image_A, image_A_gt, image_B, name
|
| 37 |
+
#
|
| 38 |
+
# @staticmethod
|
| 39 |
+
# def collate_fn(batch):
|
| 40 |
+
# images_A, image_A_gt, images_B, name = zip(*batch)
|
| 41 |
+
# images_A = torch.stack(images_A, dim=0)
|
| 42 |
+
# image_A_gt = torch.stack(image_A_gt, dim=0)
|
| 43 |
+
# images_B = torch.stack(images_B, dim=0)
|
| 44 |
+
# return images_A, image_A_gt, images_B, name
|
| 45 |
+
|
| 46 |
+
# class RandomCropPair:
|
| 47 |
+
# def __init__(self, size):
|
| 48 |
+
# self.size = size # 裁剪尺寸 (h, w)
|
| 49 |
+
#
|
| 50 |
+
# def __call__(self, vis_img, ir_img):
|
| 51 |
+
# # 获取随机裁剪参数
|
| 52 |
+
# i, j, h, w = RandomCrop.get_params(vis_img, output_size=self.size)
|
| 53 |
+
# # 对可见光和红外图像使用相同的裁剪参数
|
| 54 |
+
# vis_img = F.crop(vis_img, i, j, h, w)
|
| 55 |
+
# ir_img = F.crop(ir_img, i, j, h, w)
|
| 56 |
+
#
|
| 57 |
+
# vis_img = F.to_tensor(vis_img)
|
| 58 |
+
# ir_img = F.to_tensor(ir_img)
|
| 59 |
+
# return vis_img, ir_img
|
| 60 |
+
#
|
| 61 |
+
# class SimpleDataSet(Dataset):
|
| 62 |
+
# def __init__(self, visible_path, infrared_path, phase="train", transform=None):
|
| 63 |
+
# self.phase = phase
|
| 64 |
+
# self.visible_files = sorted(glob(os.path.join(visible_path, "*.*")))
|
| 65 |
+
# self.infrared_files = sorted(glob(os.path.join(infrared_path, "*.*")))
|
| 66 |
+
# self.transform = transform
|
| 67 |
+
#
|
| 68 |
+
# def __len__(self):
|
| 69 |
+
# l = len(self.infrared_files)
|
| 70 |
+
# return l
|
| 71 |
+
#
|
| 72 |
+
# def __getitem__(self, item):
|
| 73 |
+
# image_A_path = self.visible_files[item]
|
| 74 |
+
# image_B_path = self.infrared_files[item]
|
| 75 |
+
# image_A = Image.open(image_A_path).convert(mode='RGB')
|
| 76 |
+
# image_B = Image.open(image_B_path).convert(mode='L') ##########
|
| 77 |
+
#
|
| 78 |
+
# # image_A = self.transform(image_A)
|
| 79 |
+
# # image_B = self.transform(image_B)
|
| 80 |
+
#
|
| 81 |
+
# if self.transform is not None:
|
| 82 |
+
# if isinstance(self.transform, RandomCropPair):
|
| 83 |
+
# image_A, image_B = self.transform(image_A, image_B)
|
| 84 |
+
# else:
|
| 85 |
+
# image_A = self.transform(image_A)
|
| 86 |
+
# image_B = self.transform(image_B)
|
| 87 |
+
#
|
| 88 |
+
# name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
|
| 89 |
+
#
|
| 90 |
+
# return image_A, image_B, name
|
| 91 |
+
#
|
| 92 |
+
# @staticmethod
|
| 93 |
+
# def collate_fn(batch):
|
| 94 |
+
# images_A, images_B, name = zip(*batch)
|
| 95 |
+
# images_A = torch.stack(images_A, dim=0)
|
| 96 |
+
# images_B = torch.stack(images_B, dim=0)
|
| 97 |
+
# return images_A, images_B, name
|
| 98 |
+
#
|
| 99 |
+
#
|
| 100 |
+
# class RandomCropPair:
|
| 101 |
+
# def __init__(self, size):
|
| 102 |
+
# self.size = size # 裁剪尺寸 (h, w)
|
| 103 |
+
#
|
| 104 |
+
# def __call__(self, vis_blur_img, ir_blur_img, vis_gt_img, ir_gt_img):
|
| 105 |
+
# # 获取随机裁剪参数
|
| 106 |
+
# i, j, h, w = RandomCrop.get_params(vis_blur_img, output_size=self.size)
|
| 107 |
+
# # 对可见光和红外图像使用相同的裁剪参数
|
| 108 |
+
# vis_blur_img = F.crop(vis_blur_img, i, j, h, w)
|
| 109 |
+
# ir_blur_img = F.crop(ir_blur_img, i, j, h, w)
|
| 110 |
+
# vis_gt_img = F.crop(vis_gt_img, i, j, h, w)
|
| 111 |
+
# ir_gt_img = F.crop(ir_gt_img, i, j, h, w)
|
| 112 |
+
#
|
| 113 |
+
# vis_blur_img = F.to_tensor(vis_blur_img)
|
| 114 |
+
# ir_blur_img = F.to_tensor(ir_blur_img)
|
| 115 |
+
# vis_gt_img = F.to_tensor(vis_gt_img)
|
| 116 |
+
# ir_gt_img = F.to_tensor(ir_gt_img)
|
| 117 |
+
# return vis_blur_img, ir_blur_img, vis_gt_img, ir_gt_img
|
| 118 |
+
#
|
| 119 |
+
# class SimpleDataSet(Dataset):
|
| 120 |
+
# def __init__(self, visible_blur_path, infrared_blur_path, visible_gt_path, infrared_gt_path, phase="train", transform=None):
|
| 121 |
+
# self.phase = phase
|
| 122 |
+
# self.visible_blur_files = sorted(glob(os.path.join(visible_blur_path, "*.*")))
|
| 123 |
+
# self.infrared_blur_files = sorted(glob(os.path.join(infrared_blur_path, "*.*")))
|
| 124 |
+
# self.visible_gt_files = sorted(glob(os.path.join(visible_gt_path, "*.*")))
|
| 125 |
+
# self.infrared_gt_files = sorted(glob(os.path.join(infrared_gt_path, "*.*")))
|
| 126 |
+
# self.transform = transform
|
| 127 |
+
#
|
| 128 |
+
# def __len__(self):
|
| 129 |
+
# l = len(self.infrared_gt_files)
|
| 130 |
+
# return l
|
| 131 |
+
#
|
| 132 |
+
# def __getitem__(self, item):
|
| 133 |
+
# image_A_blur_path = self.visible_blur_files[item]
|
| 134 |
+
# image_B_blur_path = self.infrared_blur_files[item]
|
| 135 |
+
# image_A_gt_path = self.visible_gt_files[item]
|
| 136 |
+
# image_B_gt_path = self.infrared_gt_files[item]
|
| 137 |
+
# image_A_blur = Image.open(image_A_blur_path).convert(mode='RGB')
|
| 138 |
+
# image_B_blur = Image.open(image_B_blur_path).convert(mode='L') ##########
|
| 139 |
+
# image_A_gt = Image.open(image_A_gt_path).convert(mode='RGB')
|
| 140 |
+
# image_B_gt = Image.open(image_B_gt_path).convert(mode='L') ##########
|
| 141 |
+
#
|
| 142 |
+
# if self.transform is not None:
|
| 143 |
+
# if isinstance(self.transform, RandomCropPair):
|
| 144 |
+
# image_A_blur, image_B_blur, image_A_gt, image_B_gt = self.transform(image_A_blur, image_B_blur, image_A_gt, image_B_gt)
|
| 145 |
+
# else:
|
| 146 |
+
# image_A_blur = self.transform(image_A_blur)
|
| 147 |
+
# image_B_blur = self.transform(image_B_blur)
|
| 148 |
+
# image_A_gt = self.transform(image_A_gt)
|
| 149 |
+
# image_B_gt = self.transform(image_B_gt)
|
| 150 |
+
#
|
| 151 |
+
# name = image_A_blur_path.replace("\\", "/").split("/")[-1].split(".")[0]
|
| 152 |
+
#
|
| 153 |
+
# return image_A_blur, image_B_blur, image_A_gt, image_B_gt, name
|
| 154 |
+
#
|
| 155 |
+
# @staticmethod
|
| 156 |
+
# def collate_fn(batch):
|
| 157 |
+
# image_A_blur, image_B_blur, image_A_gt, image_B_gt, name = zip(*batch)
|
| 158 |
+
# image_A_blur = torch.stack(image_A_blur, dim=0)
|
| 159 |
+
# image_B_blur = torch.stack(image_B_blur, dim=0)
|
| 160 |
+
# image_A_gt = torch.stack(image_A_gt, dim=0)
|
| 161 |
+
# image_B_gt = torch.stack(image_B_gt, dim=0)
|
| 162 |
+
# return image_A_blur, image_B_blur, image_A_gt, image_B_gt, name
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class RandomCropPair:
|
| 166 |
+
def __init__(self, size):
|
| 167 |
+
self.size = size # 裁剪尺寸 (h, w)
|
| 168 |
+
|
| 169 |
+
def __call__(self, vis_img, ir_img):
|
| 170 |
+
# 获取随机裁剪参数
|
| 171 |
+
i, j, h, w = RandomCrop.get_params(vis_img, output_size=self.size)
|
| 172 |
+
# 对可见光和红外图像使用相同的裁剪参数
|
| 173 |
+
vis_img = F.crop(vis_img, i, j, h, w)
|
| 174 |
+
ir_img = F.crop(ir_img, i, j, h, w)
|
| 175 |
+
|
| 176 |
+
vis_img = F.to_tensor(vis_img)
|
| 177 |
+
ir_img = F.to_tensor(ir_img)
|
| 178 |
+
return vis_img, ir_img
|
| 179 |
+
|
| 180 |
+
class SimpleDataSet(Dataset):
|
| 181 |
+
def __init__(self, visible_path, infrared_path, phase="train", transform=None):
|
| 182 |
+
self.phase = phase
|
| 183 |
+
self.visible_files = sorted(glob(os.path.join(visible_path, "*.*")))
|
| 184 |
+
self.infrared_files = sorted(glob(os.path.join(infrared_path, "*.*")))
|
| 185 |
+
self.transform = transform
|
| 186 |
+
|
| 187 |
+
def __len__(self):
|
| 188 |
+
l = len(self.infrared_files)
|
| 189 |
+
return l
|
| 190 |
+
|
| 191 |
+
def __getitem__(self, item):
|
| 192 |
+
image_A_path = self.visible_files[item]
|
| 193 |
+
image_B_path = self.infrared_files[item]
|
| 194 |
+
image_A = Image.open(image_A_path).convert(mode='RGB')
|
| 195 |
+
image_B = Image.open(image_B_path).convert(mode='L') ##########
|
| 196 |
+
|
| 197 |
+
# image_A = self.transform(image_A)
|
| 198 |
+
# image_B = self.transform(image_B)
|
| 199 |
+
|
| 200 |
+
if self.transform is not None:
|
| 201 |
+
if isinstance(self.transform, RandomCropPair):
|
| 202 |
+
image_A, image_B = self.transform(image_A, image_B)
|
| 203 |
+
else:
|
| 204 |
+
image_A = self.transform(image_A)
|
| 205 |
+
image_B = self.transform(image_B)
|
| 206 |
+
|
| 207 |
+
name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
|
| 208 |
+
|
| 209 |
+
return image_A, image_B, name
|
| 210 |
+
|
| 211 |
+
@staticmethod
|
| 212 |
+
def collate_fn(batch):
|
| 213 |
+
images_A, images_B, name = zip(*batch)
|
| 214 |
+
images_A = torch.stack(images_A, dim=0)
|
| 215 |
+
images_B = torch.stack(images_B, dim=0)
|
| 216 |
+
return images_A, images_B, name
|
| 217 |
+
|
| 218 |
+
|
LUT-Fuse-main/fine_tune_lut.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 7 |
+
|
| 8 |
+
from data.o_fusion_dataset import DistillDataSet
|
| 9 |
+
from data.o_fusion_dataset import RandomCropPair
|
| 10 |
+
import datetime
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
import transforms as T
|
| 14 |
+
from scripts.loss_lut import fusion_loss
|
| 15 |
+
from itertools import chain
|
| 16 |
+
from scripts.calculate import OptimizableLUT, Generator_for_info, apply_fusion_4d_with_interpolation
|
| 17 |
+
|
| 18 |
+
cuda = True if torch.cuda.is_available() else False
|
| 19 |
+
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TV_4D(nn.Module):
|
| 23 |
+
def __init__(self, dim=16, output_channels=3):
|
| 24 |
+
super(TV_4D, self).__init__()
|
| 25 |
+
|
| 26 |
+
self.weight_r = torch.ones(dim, dim, dim, dim - 1, output_channels, dtype=torch.float)
|
| 27 |
+
self.weight_r[..., (0, dim - 2), :] *= 2.0
|
| 28 |
+
|
| 29 |
+
self.weight_g = torch.ones(dim, dim, dim - 1, dim, output_channels, dtype=torch.float)
|
| 30 |
+
self.weight_g[..., (0, dim - 2), :, :] *= 2.0
|
| 31 |
+
|
| 32 |
+
self.weight_b = torch.ones(dim, dim - 1, dim, dim, output_channels, dtype=torch.float)
|
| 33 |
+
self.weight_b[..., (0, dim - 2), :, :, :] *= 2.0
|
| 34 |
+
|
| 35 |
+
self.weight_ir = torch.ones(dim - 1, dim, dim, dim, output_channels, dtype=torch.float)
|
| 36 |
+
self.weight_ir[(0, dim - 2), :, :, :, :] *= 2.0
|
| 37 |
+
|
| 38 |
+
self.relu = torch.nn.ReLU()
|
| 39 |
+
|
| 40 |
+
def forward(self, LUT):
|
| 41 |
+
device = LUT.device
|
| 42 |
+
|
| 43 |
+
self.weight_r = self.weight_r.to(device)
|
| 44 |
+
|
| 45 |
+
self.weight_g = self.weight_g.to(device)
|
| 46 |
+
self.weight_b = self.weight_b.to(device)
|
| 47 |
+
self.weight_ir = self.weight_ir.to(device)
|
| 48 |
+
|
| 49 |
+
dif_r = LUT[ :, :, :, :-1, :] - LUT[ :, :, :, 1:, :]
|
| 50 |
+
dif_g = LUT[ :, :, :-1, :, :] - LUT[ :, :, 1:, :, :]
|
| 51 |
+
dif_b = LUT[ :, :-1, :, :, :] - LUT[ :, 1:, :, :, :]
|
| 52 |
+
dif_ir = LUT[ :-1, :, :, :, :] - LUT[ 1:, :, :, :, :]
|
| 53 |
+
|
| 54 |
+
tv = (torch.mean(torch.mul(dif_r ** 2, self.weight_r)) + torch.mean(torch.mul(dif_g ** 2, self.weight_g)) +
|
| 55 |
+
torch.mean(torch.mul(dif_b ** 2, self.weight_b)) + torch.mean(torch.mul(dif_ir ** 2, self.weight_ir)))
|
| 56 |
+
|
| 57 |
+
mn = (torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) +
|
| 58 |
+
torch.mean(self.relu(dif_b)) + torch.mean(self.relu(dif_ir)))
|
| 59 |
+
|
| 60 |
+
return tv, mn
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def fine_tune_lut(lut_model, Generator_context, train_loader, val_loader, device, epochs, learning_rate, save_dir="ww"):
|
| 64 |
+
TV4 = TV_4D().to(device)
|
| 65 |
+
best_val_loss = 1e5
|
| 66 |
+
Generator_context.train()
|
| 67 |
+
loss_fuction = fusion_loss()
|
| 68 |
+
optimizer = optim.Adam(chain(lut_model.parameters(), Generator_context.parameters()), lr=learning_rate)
|
| 69 |
+
# optimizer = optim.Adam(lut_model.parameters(), lr=learning_rate)
|
| 70 |
+
|
| 71 |
+
for epoch in range(epochs):
|
| 72 |
+
lut_model.train()
|
| 73 |
+
|
| 74 |
+
train_loss = 0
|
| 75 |
+
# train_loss_max = 0
|
| 76 |
+
# train_loss_text = 0
|
| 77 |
+
train_loss_l1 = 0
|
| 78 |
+
train_loss_ssim = 0
|
| 79 |
+
train_loss_tv0 = 0
|
| 80 |
+
train_loss_mn0 = 0
|
| 81 |
+
|
| 82 |
+
for step, data in enumerate(train_loader):
|
| 83 |
+
I_A, I_B, fuse, _ = data
|
| 84 |
+
# optimizer.zero_grad()
|
| 85 |
+
|
| 86 |
+
if torch.cuda.is_available():
|
| 87 |
+
I_A = I_A.to(device)
|
| 88 |
+
I_B = I_B.to(device)
|
| 89 |
+
high_quality = fuse.to(device)
|
| 90 |
+
loss_fuction = loss_fuction.to(device)
|
| 91 |
+
|
| 92 |
+
lut = lut_model()
|
| 93 |
+
|
| 94 |
+
tv0, mn0 = TV4(lut)
|
| 95 |
+
loss_tv0 = tv0
|
| 96 |
+
loss_mn0 = mn0
|
| 97 |
+
|
| 98 |
+
outputs = apply_fusion_4d_with_interpolation(I_A * 255., I_B * 255., lut, Generator_context)
|
| 99 |
+
|
| 100 |
+
l1 = F.l1_loss(outputs, high_quality)
|
| 101 |
+
ssim = loss_fuction(I_A, I_B, outputs)
|
| 102 |
+
loss_all = l1 + ssim + 10.0 * loss_mn0 + 0.0001 * loss_tv0 #+ text_loss + loss_max
|
| 103 |
+
|
| 104 |
+
loss_all.backward()
|
| 105 |
+
optimizer.step()
|
| 106 |
+
|
| 107 |
+
train_loss += loss_all.item()
|
| 108 |
+
train_loss_l1 += l1.item()
|
| 109 |
+
train_loss_ssim += ssim.item()
|
| 110 |
+
# train_loss_text += text_loss.item()
|
| 111 |
+
# train_loss_max += loss_max.item()
|
| 112 |
+
train_loss_tv0 += loss_tv0.item()
|
| 113 |
+
train_loss_mn0 += loss_mn0.item()
|
| 114 |
+
# train_loss_color += loss_color.item()
|
| 115 |
+
|
| 116 |
+
tb_writer.add_scalar("train_total_loss", train_loss/len(train_loader), epoch)
|
| 117 |
+
tb_writer.add_scalar("train_loss_l1", train_loss_l1/len(train_loader), epoch)
|
| 118 |
+
tb_writer.add_scalar("train_loss_ssim", train_loss_ssim / len(train_loader), epoch)
|
| 119 |
+
# tb_writer.add_scalar("train_loss_text", train_loss_text / len(train_loader), epoch)
|
| 120 |
+
# tb_writer.add_scalar("train_loss_max", train_loss_max / len(train_loader), epoch)
|
| 121 |
+
tb_writer.add_scalar("train_loss_tv0", train_loss_tv0/len(train_loader), epoch)
|
| 122 |
+
tb_writer.add_scalar("train_loss_mn0", train_loss_mn0/len(train_loader), epoch)
|
| 123 |
+
|
| 124 |
+
print(f"Epoch {epoch + 1}/{epochs} - Loss: {train_loss / len(train_loader):.6f} - loss_l1: {train_loss_l1 / len(train_loader):.6f} - loss_ssim: {train_loss_ssim / len(train_loader):.6f} - loss_tv: {train_loss_tv0 / len(train_loader):.6f} - loss_tv: {train_loss_tv0 / len(train_loader):.6f} ")
|
| 125 |
+
# print(f"Epoch {epoch + 1}/{epochs} - Loss: {train_loss / len(train_loader):.6f} - l1: {train_loss_l1 / len(train_loader):.6f} - loss_text: {train_loss_text / len(train_loader):.6f} - loss_max: {train_loss_max / len(train_loader):.6f}")
|
| 126 |
+
# print(f"Epoch {epoch + 1}/{epochs} - Loss: {train_loss / len(train_loader):.6f} - l1: {train_loss_l1 / len(train_loader):.6f} - tv: {train_loss_tv0 / len(train_loader):.6f} - mn: {train_loss_mn0 / len(train_loader):.6f}")
|
| 127 |
+
|
| 128 |
+
if (epoch + 1) % 10 == 0:
|
| 129 |
+
val_loss, val_loss_l1, val_loss_ssim, val_loss_tv0, val_loss_mn0 = validate_lut(lut_model, val_loader, device)
|
| 130 |
+
tb_writer.add_scalar("val_total_loss", val_loss/len(val_loader), epoch)
|
| 131 |
+
tb_writer.add_scalar("val_loss_l1", val_loss_l1/len(val_loader), epoch)
|
| 132 |
+
tb_writer.add_scalar("val_loss_ssim", val_loss_ssim / len(val_loader), epoch)
|
| 133 |
+
# tb_writer.add_scalar("val_loss_text", val_loss_text / len(val_loader), epoch)
|
| 134 |
+
# tb_writer.add_scalar("val_loss_max", val_loss_max / len(val_loader), epoch)
|
| 135 |
+
tb_writer.add_scalar("val_loss_tv0", val_loss_tv0/len(val_loader), epoch)
|
| 136 |
+
tb_writer.add_scalar("val_loss_mn0", val_loss_mn0/len(val_loader), epoch)
|
| 137 |
+
# print(f"Validation - Epoch {epoch} - Loss: {val_loss / len(val_loader):.6f} - l1: {val_loss_l1 / len(val_loader):.6f} - tv: {val_loss_tv0 / len(val_loader):.6f} - mn: {val_loss_mn0 / len(val_loader):.6f}")
|
| 138 |
+
|
| 139 |
+
if val_loss < best_val_loss :
|
| 140 |
+
best_val_loss = val_loss
|
| 141 |
+
filename = f"fine_tuned_ygcy_epoch{epoch}_valloss{val_loss:.6f}.npy"
|
| 142 |
+
full_path = os.path.join(filefold_path, filename)
|
| 143 |
+
save_lut(lut_model, full_path)
|
| 144 |
+
|
| 145 |
+
context_filename = f"generator_context_epoch{epoch}_valloss{val_loss:.6f}.pth"
|
| 146 |
+
generator_context_save_path = os.path.join(filefold_path, context_filename)
|
| 147 |
+
save_generator_context(Generator_context, save_path=generator_context_save_path)
|
| 148 |
+
|
| 149 |
+
print(f"Validation - Epoch {epoch} - Loss: {val_loss / len(val_loader):.6f} - l1: {val_loss_l1 / len(val_loader):.6f} - loss_ssim: {val_loss_ssim / len(val_loader):.6f} - loss_tv0: {val_loss_tv0 / len(val_loader):.6f} - loss_mn0: {val_loss_mn0 / len(val_loader):.6f}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def save_lut(lut_module, path):
|
| 153 |
+
|
| 154 |
+
lut_weights = lut_module().detach().cpu().numpy()
|
| 155 |
+
np.save(path, lut_weights)
|
| 156 |
+
print(f"Fine-tuned LUT saved to {path}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def validate_lut(lut_module, val_loader, device):
|
| 160 |
+
train_loss = 0
|
| 161 |
+
train_loss_mn0 = 0
|
| 162 |
+
train_loss_tv0 = 0
|
| 163 |
+
train_loss_ssim = 0
|
| 164 |
+
train_loss_l1 = 0
|
| 165 |
+
# train_loss_text = 0
|
| 166 |
+
TV4 = TV_4D().to(device)
|
| 167 |
+
|
| 168 |
+
loss_fuction = fusion_loss()
|
| 169 |
+
Generator_context.eval()
|
| 170 |
+
|
| 171 |
+
lut = lut_module()
|
| 172 |
+
with torch.no_grad():
|
| 173 |
+
for step, data in enumerate(val_loader):
|
| 174 |
+
I_A, I_B, fuse, task = data
|
| 175 |
+
if torch.cuda.is_available():
|
| 176 |
+
I_A = I_A.to(device)
|
| 177 |
+
I_B = I_B.to(device)
|
| 178 |
+
high_quality = fuse.to(device)
|
| 179 |
+
loss_fuction = loss_fuction.to(device)
|
| 180 |
+
|
| 181 |
+
outputs = apply_fusion_4d_with_interpolation(I_A * 255., I_B * 255., lut, Generator_context)
|
| 182 |
+
tv0, mn0 = TV4(lut)
|
| 183 |
+
loss_tv0 = tv0
|
| 184 |
+
loss_mn0 = mn0
|
| 185 |
+
l1 = F.l1_loss(outputs, high_quality)
|
| 186 |
+
loss_ssim = loss_fuction(I_A, I_B, outputs)
|
| 187 |
+
loss_all = l1 + loss_ssim + 0.1 * loss_mn0 + 10.0 * loss_tv0 #+ text_loss + max_loss
|
| 188 |
+
|
| 189 |
+
train_loss += loss_all.item()
|
| 190 |
+
train_loss_l1 += l1.item()
|
| 191 |
+
train_loss_ssim += loss_ssim.item()
|
| 192 |
+
# train_loss_text += text_loss.item()
|
| 193 |
+
# train_loss_max += max_loss.item()
|
| 194 |
+
train_loss_tv0 += loss_tv0.item()
|
| 195 |
+
train_loss_mn0 += loss_mn0.item()
|
| 196 |
+
|
| 197 |
+
return train_loss, train_loss_l1 , train_loss_tv0, train_loss_mn0
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def save_generator_context(generator_context, save_path="generator_context.pth"):
|
| 201 |
+
torch.save(generator_context.state_dict(), save_path)
|
| 202 |
+
print(f"Generator_for_info weights saved to {save_path}")
|
| 203 |
+
|
| 204 |
+
if __name__ == "__main__":
|
| 205 |
+
|
| 206 |
+
if os.path.exists("./finetune_lut_exp") is False:
|
| 207 |
+
os.makedirs("./finetune_lut_exp")
|
| 208 |
+
file_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
| 209 |
+
filefold_path = "./finetune_lut_exp/finetune_lut_{}".format(file_name)
|
| 210 |
+
file_log_path = os.path.join(filefold_path, "log")
|
| 211 |
+
os.makedirs(file_log_path)
|
| 212 |
+
|
| 213 |
+
tb_writer = SummaryWriter(log_dir=file_log_path)
|
| 214 |
+
|
| 215 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 216 |
+
|
| 217 |
+
lut_filepath = "ckpts/fine_tuned_lut_original.npy"
|
| 218 |
+
lut_tensor = torch.tensor(np.load(lut_filepath).astype(np.float32), device=DEVICE)
|
| 219 |
+
lut = OptimizableLUT(lut_tensor)
|
| 220 |
+
|
| 221 |
+
context_file = "ckpts/generator_context_original.pth"
|
| 222 |
+
Generator_context = Generator_for_info().to(DEVICE)
|
| 223 |
+
Generator_context.load_state_dict(torch.load(context_file))
|
| 224 |
+
# Generator_context.eval()
|
| 225 |
+
|
| 226 |
+
batch_size = 6
|
| 227 |
+
visible_path = " "
|
| 228 |
+
infrared_path = " "
|
| 229 |
+
train_fusion_path = " "
|
| 230 |
+
test_visible_path = " "
|
| 231 |
+
test_infrared_path = " "
|
| 232 |
+
test_fusion_path = " "
|
| 233 |
+
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
|
| 234 |
+
data_transform = {
|
| 235 |
+
"train": RandomCropPair(size=(128, 128)),
|
| 236 |
+
"val": T.Compose([T.Resize_16(),
|
| 237 |
+
T.ToTensor()])}
|
| 238 |
+
|
| 239 |
+
train_dataset = DistillDataSet(visible_path=visible_path,
|
| 240 |
+
infrared_path=infrared_path,
|
| 241 |
+
other_fuse_path=train_fusion_path,
|
| 242 |
+
phase="train",
|
| 243 |
+
transform=data_transform["train"])
|
| 244 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,
|
| 245 |
+
batch_size=batch_size,
|
| 246 |
+
shuffle=True,
|
| 247 |
+
pin_memory=True,
|
| 248 |
+
num_workers=nw,
|
| 249 |
+
collate_fn=train_dataset.collate_fn)
|
| 250 |
+
|
| 251 |
+
val_dataset = DistillDataSet(visible_path=test_visible_path,
|
| 252 |
+
infrared_path=test_infrared_path,
|
| 253 |
+
other_fuse_path=test_fusion_path,
|
| 254 |
+
phase="val",
|
| 255 |
+
transform=data_transform["val"])
|
| 256 |
+
val_loader = torch.utils.data.DataLoader(val_dataset,
|
| 257 |
+
batch_size=1,
|
| 258 |
+
shuffle=False,
|
| 259 |
+
pin_memory=True,
|
| 260 |
+
num_workers=nw,
|
| 261 |
+
collate_fn=val_dataset.collate_fn)
|
| 262 |
+
|
| 263 |
+
fine_tune_lut(lut, Generator_context, train_loader, val_loader, DEVICE, epochs=496, learning_rate=5e-5)
|
| 264 |
+
|
LUT-Fuse-main/requirements.txt
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.1.0
|
| 2 |
+
asttokens==2.4.1
|
| 3 |
+
astunparse==1.6.3
|
| 4 |
+
backcall==0.2.0
|
| 5 |
+
Brotli
|
| 6 |
+
cachetools==5.5.0
|
| 7 |
+
certifi
|
| 8 |
+
charset-normalizer
|
| 9 |
+
contourpy==1.1.1
|
| 10 |
+
cycler==0.12.1
|
| 11 |
+
decorator==5.1.1
|
| 12 |
+
einops==0.8.0
|
| 13 |
+
et_xmlfile==2.0.0
|
| 14 |
+
executing==2.1.0
|
| 15 |
+
filelock
|
| 16 |
+
flatbuffers==24.3.25
|
| 17 |
+
fonttools==4.54.1
|
| 18 |
+
fsspec==2024.12.0
|
| 19 |
+
ftfy==6.2.3
|
| 20 |
+
gast==0.4.0
|
| 21 |
+
gmpy2
|
| 22 |
+
google-auth==2.35.0
|
| 23 |
+
google-auth-oauthlib==1.0.0
|
| 24 |
+
google-pasta==0.2.0
|
| 25 |
+
grpcio==1.66.2
|
| 26 |
+
h5py==3.11.0
|
| 27 |
+
huggingface-hub==0.27.0
|
| 28 |
+
idna
|
| 29 |
+
imageio==2.35.1
|
| 30 |
+
importlib_metadata==8.5.0
|
| 31 |
+
importlib_resources==6.4.5
|
| 32 |
+
ipython==8.12.3
|
| 33 |
+
jedi==0.19.1
|
| 34 |
+
Jinja2
|
| 35 |
+
keras==2.13.1
|
| 36 |
+
kiwisolver==1.4.7
|
| 37 |
+
lazy_loader==0.4
|
| 38 |
+
libclang==18.1.1
|
| 39 |
+
Markdown==3.7
|
| 40 |
+
MarkupSafe
|
| 41 |
+
matplotlib==3.7.5
|
| 42 |
+
matplotlib-inline==0.1.7
|
| 43 |
+
mkl-fft
|
| 44 |
+
mkl-random
|
| 45 |
+
mkl-service==2.4.0
|
| 46 |
+
mpmath
|
| 47 |
+
networkx
|
| 48 |
+
numpy
|
| 49 |
+
nvidia-ml-py==12.535.161
|
| 50 |
+
nvitop==1.3.2
|
| 51 |
+
oauthlib==3.2.2
|
| 52 |
+
opencv-python==4.9.0.80
|
| 53 |
+
openpyxl==3.1.5
|
| 54 |
+
opt_einsum==3.4.0
|
| 55 |
+
packaging==24.1
|
| 56 |
+
pandas==2.0.3
|
| 57 |
+
parso==0.8.4
|
| 58 |
+
pexpect==4.9.0
|
| 59 |
+
pickleshare==0.7.5
|
| 60 |
+
pillow
|
| 61 |
+
prompt_toolkit==3.0.48
|
| 62 |
+
protobuf==4.25.5
|
| 63 |
+
psutil==6.1.0
|
| 64 |
+
ptyprocess==0.7.0
|
| 65 |
+
pure_eval==0.2.3
|
| 66 |
+
pyasn1==0.6.1
|
| 67 |
+
pyasn1_modules==0.4.1
|
| 68 |
+
Pygments==2.18.0
|
| 69 |
+
pyparsing==3.1.4
|
| 70 |
+
PySocks
|
| 71 |
+
python-dateutil==2.9.0.post0
|
| 72 |
+
pytorch-msssim==1.0.0
|
| 73 |
+
pytz==2024.2
|
| 74 |
+
PyWavelets==1.4.1
|
| 75 |
+
PyYAML==6.0.2
|
| 76 |
+
regex==2024.9.11
|
| 77 |
+
requests
|
| 78 |
+
requests-oauthlib==2.0.0
|
| 79 |
+
rsa==4.9
|
| 80 |
+
safetensors==0.4.5
|
| 81 |
+
scikit-image==0.21.0
|
| 82 |
+
scipy==1.10.1
|
| 83 |
+
seaborn==0.13.2
|
| 84 |
+
six==1.16.0
|
| 85 |
+
stack-data==0.6.3
|
| 86 |
+
sympy
|
| 87 |
+
tensorboard==2.13.0
|
| 88 |
+
tensorboard-data-server==0.7.2
|
| 89 |
+
termcolor==2.4.0
|
| 90 |
+
tifffile==2023.7.10
|
| 91 |
+
tokenizers==0.20.3
|
| 92 |
+
torch==2.0.0
|
| 93 |
+
torchvision==0.15.0
|
| 94 |
+
tqdm==4.66.5
|
| 95 |
+
traitlets==5.14.3
|
| 96 |
+
transformers==4.46.3
|
| 97 |
+
triton==2.0.0
|
| 98 |
+
typing_extensions==4.5.0
|
| 99 |
+
tzdata==2024.2
|
| 100 |
+
urllib3
|
| 101 |
+
wcwidth==0.2.13
|
| 102 |
+
Werkzeug==3.0.4
|
| 103 |
+
wrapt==1.16.0
|
| 104 |
+
zipp==3.20.2
|
LUT-Fuse-main/scripts/calculate.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision.transforms import ToPILImage
|
| 6 |
+
import time
|
| 7 |
+
from data.simple_dataset import RandomCropPair
|
| 8 |
+
from data.simple_dataset import SimpleDataSet
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import transforms as T
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def rgb_to_ycbcr(img):
|
| 15 |
+
return torch.stack(
|
| 16 |
+
(0. / 256. + img[:, 0, :, :] * 0.299000 + img[:, 1, :, :] * 0.587000 + img[:, 2, :, :] * 0.114000,
|
| 17 |
+
128. / 256. - img[:, 0, :, :] * 0.168736 - img[:, 1, :, :] * 0.331264 + img[:, 2, :, :] * 0.500000,
|
| 18 |
+
128. / 256. + img[:, 0, :, :] * 0.500000 - img[:, 1, :, :] * 0.418688 - img[:, 2, :, :] * 0.081312),
|
| 19 |
+
dim=1)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def ycbcr_to_rgb(img):
|
| 23 |
+
return torch.stack(
|
| 24 |
+
(img[:, 0, :, :] + (img[:, 2, :, :] - 0.5) * 1.402,
|
| 25 |
+
img[:, 0, :, :] - (img[:, 1, :, :] - 0.5) * 0.344136 - (img[:, 2, :, :] - 0.5) * 0.714136,
|
| 26 |
+
img[:, 0, :, :] + (img[:, 1, :, :] - 0.5) * 1.772),
|
| 27 |
+
dim=1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_lookup_table(filepath):
|
| 31 |
+
try:
|
| 32 |
+
lut = np.load(filepath).astype(np.float32)
|
| 33 |
+
lut = torch.tensor(lut, device="cuda") # 将查找表移到 GPU 上
|
| 34 |
+
return lut
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"加载查找表时出错: {e}")
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def generator_block(in_filters, out_filters, normalization=False):
|
| 41 |
+
"""Returns downsampling layers of each discriminator block"""
|
| 42 |
+
layers = [nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1)]
|
| 43 |
+
layers.append(nn.LeakyReLU(0.2))
|
| 44 |
+
if normalization:
|
| 45 |
+
layers.append(nn.InstanceNorm2d(out_filters, affine=True))
|
| 46 |
+
# layers.append(nn.BatchNorm2d(out_filters))
|
| 47 |
+
return layers
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Generator_for_info(nn.Module):
|
| 51 |
+
def __init__(self, in_channels=4):
|
| 52 |
+
super(Generator_for_info, self).__init__()
|
| 53 |
+
|
| 54 |
+
self.input_layer = nn.Sequential(
|
| 55 |
+
nn.Conv2d(in_channels, 16, 3, stride=1, padding=1),
|
| 56 |
+
nn.LeakyReLU(0.2),
|
| 57 |
+
nn.InstanceNorm2d(16, affine=True),)
|
| 58 |
+
|
| 59 |
+
self.mid_layer = nn.Sequential(
|
| 60 |
+
*generator_block(16, 16, normalization=True),
|
| 61 |
+
*generator_block(16, 16, normalization=True),
|
| 62 |
+
*generator_block(16, 16, normalization=True),)
|
| 63 |
+
|
| 64 |
+
self.output_layer = nn.Sequential(
|
| 65 |
+
nn.Dropout(p=0.5),
|
| 66 |
+
nn.Conv2d(16, 1, 3, stride=1, padding=1),
|
| 67 |
+
nn.Sigmoid())
|
| 68 |
+
|
| 69 |
+
def forward(self, img_input):
|
| 70 |
+
x = self.input_layer(img_input)
|
| 71 |
+
identity = x
|
| 72 |
+
out = self.mid_layer(x)
|
| 73 |
+
out += identity
|
| 74 |
+
out = self.output_layer(out)
|
| 75 |
+
return out
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def apply_fusion_4d_with_interpolation(visible_img, infrared_img, lut, get_context):
|
| 79 |
+
|
| 80 |
+
image_cat = torch.cat((visible_img, infrared_img), dim=1) # [0, 255]
|
| 81 |
+
context = get_context(image_cat)
|
| 82 |
+
|
| 83 |
+
context_scaled = (context *255./ 16.0).squeeze(1) # [0, 16]
|
| 84 |
+
infrared_scaled = infrared_img / 16.0 # [0, 16]
|
| 85 |
+
|
| 86 |
+
ycbcr_vis = rgb_to_ycbcr(visible_img / 255.) # [0, 1]
|
| 87 |
+
ycbcr_vis_scaled = ycbcr_vis * 255.0 / 16.0 # [0, 16]
|
| 88 |
+
|
| 89 |
+
y_vi_scaled = ycbcr_vis_scaled[:, 0, :, :] # [b, 1, h, w] # [0, 16]
|
| 90 |
+
cb_cr = ycbcr_vis[:, 1:, :, :] # [0, 1]
|
| 91 |
+
ir_scaled = infrared_scaled[:, 0, :, :] # [0, 16]
|
| 92 |
+
|
| 93 |
+
# 获取floor和ceil索引
|
| 94 |
+
ir_floor = torch.floor(ir_scaled).long()
|
| 95 |
+
ir_ceil = torch.clamp(ir_floor + 1, 0, lut.shape[3] - 1)
|
| 96 |
+
ir_alpha = ir_scaled - ir_floor
|
| 97 |
+
|
| 98 |
+
y_vi_floor = torch.floor(y_vi_scaled).long()
|
| 99 |
+
y_vi_ceil = torch.clamp(y_vi_floor + 1, 0, lut.shape[0] - 1)
|
| 100 |
+
y_vi_alpha = y_vi_scaled - y_vi_floor
|
| 101 |
+
|
| 102 |
+
c_floor = torch.floor(context_scaled).long()
|
| 103 |
+
c_ceil = torch.clamp(c_floor + 1, 0, lut.shape[2] - 1)
|
| 104 |
+
c_alpha = context_scaled - c_floor
|
| 105 |
+
|
| 106 |
+
sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(
|
| 107 |
+
visible_img.device)
|
| 108 |
+
sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(
|
| 109 |
+
visible_img.device)
|
| 110 |
+
|
| 111 |
+
# 计算 x 和 y 方向的梯度
|
| 112 |
+
grad_x = torch.nn.functional.conv2d(ycbcr_vis[:, :1, :, :], sobel_x, padding=1)
|
| 113 |
+
grad_y = torch.nn.functional.conv2d(ycbcr_vis[:, :1, :, :], sobel_y, padding=1)
|
| 114 |
+
|
| 115 |
+
# 计算梯度幅值
|
| 116 |
+
gradient = torch.sqrt(grad_x ** 2 + grad_y ** 2) # [b, 1, h, w]
|
| 117 |
+
min_val = gradient.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values
|
| 118 |
+
max_val = gradient.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values
|
| 119 |
+
|
| 120 |
+
gradient_normalized = (gradient - min_val) / (max_val - min_val + 1e-8)
|
| 121 |
+
|
| 122 |
+
gradient_scaled = (gradient_normalized * 255.)
|
| 123 |
+
gradient_scaled = (gradient_scaled / 16.0).squeeze(1)
|
| 124 |
+
|
| 125 |
+
g_floor = torch.floor(gradient_scaled).long()
|
| 126 |
+
g_ceil = torch.clamp(g_floor + 1, 0, lut.shape[1] - 1)
|
| 127 |
+
g_alpha = gradient_scaled - g_floor
|
| 128 |
+
|
| 129 |
+
ir_alpha = ir_alpha.unsqueeze(-1)
|
| 130 |
+
y_vi_alpha = y_vi_alpha.unsqueeze(-1)
|
| 131 |
+
g_alpha = g_alpha.unsqueeze(-1)
|
| 132 |
+
c_alpha = c_alpha.unsqueeze(-1)
|
| 133 |
+
|
| 134 |
+
def lerp(v1, v2, alpha):
|
| 135 |
+
out = v1 * (1 - alpha) + v2 * alpha
|
| 136 |
+
return out
|
| 137 |
+
|
| 138 |
+
fusion_result = (
|
| 139 |
+
lerp(
|
| 140 |
+
lerp(
|
| 141 |
+
lerp(
|
| 142 |
+
lerp(lut[y_vi_floor, g_floor, c_floor, ir_floor], lut[y_vi_floor, g_floor, c_floor, ir_ceil],
|
| 143 |
+
ir_alpha),
|
| 144 |
+
lerp(lut[y_vi_floor, g_floor, c_ceil, ir_floor], lut[y_vi_floor, g_floor, c_ceil, ir_ceil],
|
| 145 |
+
ir_alpha),
|
| 146 |
+
c_alpha,
|
| 147 |
+
),
|
| 148 |
+
lerp(
|
| 149 |
+
lerp(lut[y_vi_floor, g_ceil, c_floor, ir_floor], lut[y_vi_floor, g_ceil, c_floor, ir_ceil],
|
| 150 |
+
ir_alpha),
|
| 151 |
+
lerp(lut[y_vi_floor, g_ceil, c_ceil, ir_floor], lut[y_vi_floor, g_ceil, c_ceil, ir_ceil], ir_alpha),
|
| 152 |
+
c_alpha,
|
| 153 |
+
),
|
| 154 |
+
g_alpha,
|
| 155 |
+
),
|
| 156 |
+
lerp(
|
| 157 |
+
lerp(
|
| 158 |
+
lerp(lut[y_vi_ceil, g_floor, c_floor, ir_floor], lut[y_vi_ceil, g_floor, c_floor, ir_ceil],
|
| 159 |
+
ir_alpha),
|
| 160 |
+
lerp(lut[y_vi_ceil, g_floor, c_ceil, ir_floor], lut[y_vi_ceil, g_floor, c_ceil, ir_ceil], ir_alpha),
|
| 161 |
+
c_alpha,
|
| 162 |
+
),
|
| 163 |
+
lerp(
|
| 164 |
+
lerp(lut[y_vi_ceil, g_ceil, c_floor, ir_floor], lut[y_vi_ceil, g_ceil, c_floor, ir_ceil], ir_alpha),
|
| 165 |
+
lerp(lut[y_vi_ceil, g_ceil, c_ceil, ir_floor], lut[y_vi_ceil, g_ceil, c_ceil, ir_ceil], ir_alpha),
|
| 166 |
+
c_alpha,
|
| 167 |
+
),
|
| 168 |
+
g_alpha,
|
| 169 |
+
),
|
| 170 |
+
y_vi_alpha,
|
| 171 |
+
)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
fusion_y = fusion_result.permute(0, 3, 1, 2)
|
| 175 |
+
|
| 176 |
+
fusion_ycbcr = torch.cat([fusion_y, cb_cr], dim=1)
|
| 177 |
+
fusion_rgb = ycbcr_to_rgb(fusion_ycbcr) # fusion_rgb = fusion_ycbcr.permute(0, 3, 1, 2)
|
| 178 |
+
|
| 179 |
+
return fusion_rgb
|
| 180 |
+
|
LUT-Fuse-main/scripts/loss_lut.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from math import exp
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class fusion_loss(nn.Module):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super(fusion_loss, self).__init__()
|
| 10 |
+
self.loss_func_ssim = L_SSIM(window_size=48)
|
| 11 |
+
self.loss_func_Grad = GradientMaxLoss()
|
| 12 |
+
self.loss_func_l1 = nn.L1Loss() # 添加 L1 损失
|
| 13 |
+
self.loss_func_l2 = nn.MSELoss() # 添加 L2 损失
|
| 14 |
+
self.loss_func_color = L_color()
|
| 15 |
+
self.loss_func_Max = L_Intensity_Max_RGB()
|
| 16 |
+
|
| 17 |
+
def forward(self, image_vi, image_ir, image_fused, max_ratio=4, consist_ratio=1, ssim_ir_ratio=1,
|
| 18 |
+
ssim_ratio=1, ir_compose=1, color_ratio=12, text_ratio=2, max_mode="l1", consist_mode="l1"):
|
| 19 |
+
image_visible_gray = self.rgb2gray(image_vi)
|
| 20 |
+
image_infrared_gray = self.rgb2gray(image_ir)
|
| 21 |
+
image_fused_gray = self.rgb2gray(image_fused)
|
| 22 |
+
# loss_text = text_ratio * self.loss_func_Grad(image_visible_gray, image_infrared_gray, image_fused_gray)
|
| 23 |
+
# loss_max = max_ratio * self.loss_func_Max(image_vi, image_ir, image_fused, max_mode)
|
| 24 |
+
loss_ssim = ssim_ratio * (self.loss_func_ssim(image_vi, image_fused) + ssim_ir_ratio * self.loss_func_ssim(image_ir, image_fused_gray))
|
| 25 |
+
return loss_ssim
|
| 26 |
+
|
| 27 |
+
def rgb2gray(self, image):
|
| 28 |
+
b, c, h, w = image.size()
|
| 29 |
+
if c == 1:
|
| 30 |
+
return image
|
| 31 |
+
image_gray = 0.299 * image[:, 0, :, :] + 0.587 * image[:, 1, :, :] + 0.114 * image[:, 2, :, :]
|
| 32 |
+
image_gray = image_gray.unsqueeze(dim=1)
|
| 33 |
+
return image_gray
|
| 34 |
+
|
| 35 |
+
def gaussian(window_size, sigma):
|
| 36 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
| 37 |
+
return gauss / gauss.sum()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class L_Intensity_Max_RGB(nn.Module):
|
| 41 |
+
def __init__(self):
|
| 42 |
+
super(L_Intensity_Max_RGB, self).__init__()
|
| 43 |
+
|
| 44 |
+
def forward(self, image_visible, image_infrared, image_fused, max_mode="l1"):
|
| 45 |
+
gray_visible = torch.mean(image_visible, dim=1, keepdim=True)
|
| 46 |
+
gray_infrared = torch.mean(image_infrared, dim=1, keepdim=True)
|
| 47 |
+
|
| 48 |
+
mask = (gray_infrared > gray_visible).float()
|
| 49 |
+
|
| 50 |
+
fused_image = mask * image_infrared + (1 - mask) * image_visible
|
| 51 |
+
if max_mode == "l1":
|
| 52 |
+
Loss_intensity = F.l1_loss(fused_image, image_fused)
|
| 53 |
+
else:
|
| 54 |
+
Loss_intensity = F.mse_loss(fused_image, image_fused)
|
| 55 |
+
return Loss_intensity
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def create_window(window_size, channel=1):
|
| 59 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
| 60 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
| 61 |
+
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
| 62 |
+
return window
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def ssim(img1, img2, window_size=24, window=None, size_average=True, val_range=None):
|
| 66 |
+
|
| 67 |
+
if val_range is None:
|
| 68 |
+
if torch.max(img1) > 128:
|
| 69 |
+
max_val = 255
|
| 70 |
+
else:
|
| 71 |
+
max_val = 1
|
| 72 |
+
|
| 73 |
+
if torch.min(img1) < -0.5:
|
| 74 |
+
min_val = -1
|
| 75 |
+
else:
|
| 76 |
+
min_val = 0
|
| 77 |
+
L = max_val - min_val
|
| 78 |
+
else:
|
| 79 |
+
L = val_range
|
| 80 |
+
|
| 81 |
+
padd = 0
|
| 82 |
+
(_, channel, height, width) = img1.size()
|
| 83 |
+
if window is None:
|
| 84 |
+
real_size = min(window_size, height, width)
|
| 85 |
+
window = create_window(real_size, channel=channel).to(img1.device)
|
| 86 |
+
|
| 87 |
+
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    return 1 - ret


class GradientMaxLoss(nn.Module):
    # Maximum-gradient loss: the fused image's Sobel gradients should follow the
    # element-wise maximum of the two source images' gradients.
    def __init__(self):
        super(GradientMaxLoss, self).__init__()
        self.sobel_x = nn.Parameter(torch.FloatTensor([[-1, 0, 1],
                                                       [-2, 0, 2],
                                                       [-1, 0, 1]]).view(1, 1, 3, 3), requires_grad=False).cuda()
        self.sobel_y = nn.Parameter(torch.FloatTensor([[-1, -2, -1],
                                                       [0, 0, 0],
                                                       [1, 2, 1]]).view(1, 1, 3, 3), requires_grad=False).cuda()
        self.padding = (1, 1, 1, 1)

    def forward(self, image_A, image_B, image_fuse):
        gradient_A_x, gradient_A_y = self.gradient(image_A)
        gradient_B_x, gradient_B_y = self.gradient(image_B)
        gradient_fuse_x, gradient_fuse_y = self.gradient(image_fuse)
        loss = F.l1_loss(gradient_fuse_x, torch.max(gradient_A_x, gradient_B_x)) + F.l1_loss(gradient_fuse_y, torch.max(gradient_A_y, gradient_B_y))
        return loss

    def gradient(self, image):
        image = F.pad(image, self.padding, mode='replicate')
        gradient_x = F.conv2d(image, self.sobel_x, padding=0)
        gradient_y = F.conv2d(image, self.sobel_y, padding=0)
        return torch.abs(gradient_x), torch.abs(gradient_y)


class L_SSIM(torch.nn.Module):
    # Structural-similarity loss wrapper; ssim() above already returns 1 - SSIM.
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(L_SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        (_, channel_2, _, _) = img2.size()

        if channel != channel_2 and channel == 1:
            img1 = torch.concat([img1, img1, img1], dim=1)
            channel = 3

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window.cuda()
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window.cuda()
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


class L_color(nn.Module):
    # Chrominance-consistency loss: the fused image's Cb/Cr channels should stay
    # close to those of the visible image.
    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, image_visible, image_fused):
        ycbcr_visible = self.rgb_to_ycbcr(image_visible)
        ycbcr_fused = self.rgb_to_ycbcr(image_fused)

        cb_visible = ycbcr_visible[:, 1, :, :]
        cr_visible = ycbcr_visible[:, 2, :, :]
        cb_fused = ycbcr_fused[:, 1, :, :]
        cr_fused = ycbcr_fused[:, 2, :, :]

        loss_cb = F.l1_loss(cb_visible, cb_fused)
        loss_cr = F.l1_loss(cr_visible, cr_fused)

        loss_color = loss_cb + loss_cr
        return loss_color

    def rgb_to_ycbcr(self, image):
        r = image[:, 0, :, :]
        g = image[:, 1, :, :]
        b = image[:, 2, :, :]

        y = 0.299 * r + 0.587 * g + 0.114 * b
        cb = -0.168736 * r - 0.331264 * g + 0.5 * b
        cr = 0.5 * r - 0.418688 * g - 0.081312 * b

        ycbcr_image = torch.stack((y, cb, cr), dim=1)
        return ycbcr_image
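For orientation, here is a minimal sketch of how the loss modules defined above could be combined into a single fusion objective. The class name FusionLoss, the loss weights, and the assumption that the SSIM and gradient terms receive single-channel luminance tensors while the color term receives RGB tensors are illustrative choices only, not the repository's actual training recipe.

import torch.nn as nn

class FusionLoss(nn.Module):
    """Hypothetical combination of the losses above; weights are illustrative."""
    def __init__(self, w_ssim=1.0, w_grad=1.0, w_color=0.5):
        super(FusionLoss, self).__init__()
        self.ssim = L_SSIM()            # structural similarity (returns 1 - SSIM)
        self.grad = GradientMaxLoss()   # fused gradients track the max of source gradients
        self.color = L_color()          # Cb/Cr consistency with the visible image
        self.w_ssim, self.w_grad, self.w_color = w_ssim, w_grad, w_color

    def forward(self, vis_y, ir_y, fused_y, vis_rgb, fused_rgb):
        # vis_y / ir_y / fused_y: (B, 1, H, W) luminance tensors on the GPU
        # vis_rgb / fused_rgb:    (B, 3, H, W) RGB tensors on the GPU
        loss_ssim = self.ssim(fused_y, vis_y) + self.ssim(fused_y, ir_y)
        loss_grad = self.grad(vis_y, ir_y, fused_y)
        loss_color = self.color(vis_rgb, fused_rgb)
        return self.w_ssim * loss_ssim + self.w_grad * loss_grad + self.w_color * loss_color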
LUT-Fuse-main/test_lut.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.transforms import ToPILImage
import time
from data.simple_dataset import RandomCropPair
from data.simple_dataset import SimpleDataSet
import torch.nn as nn
import transforms as T
import os
from scripts.calculate import load_lookup_table, Generator_for_info, apply_fusion_4d_with_interpolation


def main():
    # NOTE: the checkpoint and data paths are left blank in the repository and
    # must be filled in before running.
    lut_filepath = " "
    context_file = " "
    infrared_dir = " "
    visible_dir = " "
    save_dir = " "
    os.makedirs(save_dir, exist_ok=True)
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Check for a missing LUT before moving it to the target device.
    lut = load_lookup_table(lut_filepath)
    if lut is None:
        return
    lut = lut.to(device)

    get_context = Generator_for_info()
    get_context.load_state_dict(torch.load(context_file))
    get_context = get_context.to(device)
    get_context.eval()

    data_transform = {
        "train": RandomCropPair(size=(96, 96)),
        "val": T.Compose([T.Resize_16(),
                          T.ToTensor()])}

    val_dataset = SimpleDataSet(visible_path=visible_dir,
                                infrared_path=infrared_dir,
                                phase="val",
                                transform=data_transform["val"])
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=1,
                                             collate_fn=val_dataset.collate_fn)

    infrared_files = sorted(os.listdir(infrared_dir))
    visible_files = sorted(os.listdir(visible_dir))

    assert len(infrared_files) == len(visible_files), "The number of images in the infrared and visible folders does not match!"
    target_size = (128, 128)
    times = []

    for step, data in enumerate(val_loader):
        I_A, I_B, task = data

        if torch.cuda.is_available():
            I_A = I_A.to("cuda")
            I_B = I_B.to("cuda")

        # Synchronize around the fusion call so the timing covers the GPU work.
        torch.cuda.synchronize()
        start_time = time.time()
        outputs = apply_fusion_4d_with_interpolation(I_A * 255., I_B * 255., lut, get_context)
        torch.cuda.synchronize()
        end_time = time.time()
        elapsed_time = end_time - start_time
        times.append(elapsed_time)

        if not os.path.splitext(task[0])[1]:
            task_with_extension = task[0] + ".png"
        else:
            task_with_extension = task[0]
        save_path = os.path.join(save_dir, task_with_extension)
        fusion_result = outputs.squeeze(0).clamp(0, 1).cpu()
        fusion_result_image = ToPILImage()(fusion_result)
        fusion_result_image.save(save_path)

    # Report average runtime, skipping the first images as GPU warm-up.
    warmup_skip = 25
    if len(times) > warmup_skip:
        times_after_warmup = times[warmup_skip:]
        avg_time = np.mean(times_after_warmup)
        std_time = np.std(times_after_warmup)
        print(f"Processing completed! After skipping the first {warmup_skip} images, avg_time: {avg_time:.4f} seconds, std_time: {std_time:.4f} seconds")
    else:
        print(f"Not enough images to skip the first {warmup_skip}! Total images: {len(times)}")


if __name__ == "__main__":
    main()
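A hypothetical way to fill in the placeholders above before launching the script: the checkpoint names below follow the files shipped under ckpts/, while the dataset and output directories are stand-ins for your own layout, not paths defined by the repository.

# Hypothetical values for the blank paths in main(); adjust to your setup.
lut_filepath = "ckpts/fine_tuned_lut.npy"        # learned LUT weights
context_file = "ckpts/generator_context.pth"     # context-generator checkpoint
infrared_dir = "datasets/test/ir"                # infrared test images (assumed layout)
visible_dir = "datasets/test/vi"                 # visible test images (assumed layout)
save_dir = "results/lut_fuse"                    # fused outputs are written here

# The script takes no command-line arguments; run it directly:
#   python test_lut.py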
LUT-Fuse-main/transforms.py
ADDED
@@ -0,0 +1,107 @@
import numpy as np
import random

import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


def pad_if_smaller(img, size, fill=0):
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        for t in self.transforms:
            image = t(image)
        return image


class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        image = F.resize(image, self.size)
        return image


class Resize_16(object):
    def __init__(self):
        pass

    def __call__(self, image):
        width, height = image.size
        new_width = (width // 16) * 16
        new_height = (height // 16) * 16

        image = F.resize(image, (new_height, new_width))

        return image


class Resize_20(object):
    def __init__(self):
        pass

    def __call__(self, image):
        width, height = image.size
        new_width = (width // 20) * 20
        new_height = (height // 20) * 20

        image = F.resize(image, (new_height, new_width))

        return image


class RandomHorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
        return image


class RandomVerticalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image):
        if random.random() < self.flip_prob:
            image = F.vflip(image)
        return image


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        image = pad_if_smaller(image, self.size)
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        return image


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        image = F.center_crop(image, self.size)

        return image


class ToTensor(object):
    def __call__(self, image):
        image = F.to_tensor(image)
        return image
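As a usage illustration, the snippet below composes these transforms into the same validation pipeline that test_lut.py builds, plus a hypothetical training-style pipeline. The flip probabilities, the crop size, and the example file name are illustrative values only, not settings taken from the repository.

from PIL import Image

import transforms as T

# Validation-style preprocessing used by test_lut.py: snap both sides to
# multiples of 16, then convert to a float tensor in [0, 1].
val_transform = T.Compose([T.Resize_16(),
                           T.ToTensor()])

# Hypothetical augmentation pipeline; the 0.5 flip probability and the
# 96-pixel crop are illustrative values, not the repository's settings.
train_transform = T.Compose([T.RandomHorizontalFlip(0.5),
                             T.RandomVerticalFlip(0.5),
                             T.RandomCrop(96),
                             T.ToTensor()])

img = Image.open("example_visible.png").convert("RGB")   # hypothetical input file
val_tensor = val_transform(img)      # (3, H', W') with H', W' multiples of 16
train_tensor = train_transform(img)  # (3, 96, 96)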