ZYB5 committed · Commit 24ebe72 · verified · 1 Parent(s): bce38e1

Upload 15 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ LUT-Fuse-main/assets/framework.png filter=lfs diff=lfs merge=lfs -text
LUT-Fuse-main/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Yibing Zhang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
LUT-Fuse-main/README.md ADDED
@@ -0,0 +1,84 @@
+ <h1 align="center">[ICCV 2025] LUT-Fuse</h1>
+ <p align="center">
+   <em>Towards Extremely Fast Infrared and Visible Image Fusion via Distillation to Learnable Look-Up Tables</em>
+ </p>
+
+ <p align="center">
+   <a href="https://github.com/zyb5/LUT-Fuse" style="text-decoration:none;">
+     <img src="https://img.shields.io/badge/GitHub-Code-black?logo=github" alt="Code" />
+   </a>
+   <a href="https://arxiv.org/abs/2509.00346" style="text-decoration:none; margin-left:8px;">
+     <img src="https://img.shields.io/badge/arXiv-Paper-B31B1B?logo=arxiv" alt="Paper" />
+   </a>
+ </p>
+
+ <p align="center">
+   <img src="assets/framework.png" alt="LUT-Fuse Framework" width="90%">
+ </p>
+
+ ---
+
+ ## ⚙️ Environment
+
+ ```
+ conda create -n lutfuse python=3.8
+ conda activate lutfuse
+ ```
+
+ ```
+ conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+ pip install -r requirements.txt
+ ```
+
+ ## 📂 Dataset
+
+ Organize your dataset according to the following structure:
+
+ ```
+ |dataset
+   |train
+     |Infrared
+     |Visible
+     |Fuse_ref
+   |test
+     |Infrared
+     |Visible
+     |Fuse_ref
+ ```
+
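+ For reference, `fine_tune_lut.py` builds its training loader from these folders. A minimal sketch (the folder paths are placeholders derived from the layout above):
+
+ ```python
+ from torch.utils.data import DataLoader
+ from data.o_fusion_dataset import DistillDataSet, RandomCropPair
+
+ # placeholder paths following the layout above
+ train_dataset = DistillDataSet(visible_path="dataset/train/Visible",
+                                infrared_path="dataset/train/Infrared",
+                                other_fuse_path="dataset/train/Fuse_ref",
+                                phase="train",
+                                transform=RandomCropPair(size=(128, 128)))
+ train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True,
+                           collate_fn=train_dataset.collate_fn)
+ ```
+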
+ ## 💾 Checkpoints
+
+ We provide our **pretrained checkpoints** directly in this repository for convenience.
+ You can find them under [`./ckpts`](./ckpts).
+
+ - **Fusion LUT weights:** `ckpts/fine_tuned_lut.npy`
+ - **Context generator weights:** `ckpts/generator_context.pth`
+
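+ They can be loaded the same way `test_lut.py` and `fine_tune_lut.py` load them. A minimal sketch (paths assume the repository root as the working directory):
+
+ ```python
+ import numpy as np
+ import torch
+ from scripts.calculate import Generator_for_info
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # 4D fusion LUT stored as a NumPy array
+ lut = torch.tensor(np.load("ckpts/fine_tuned_lut.npy").astype(np.float32), device=device)
+
+ # context generator weights
+ generator = Generator_for_info().to(device)
+ generator.load_state_dict(torch.load("ckpts/generator_context.pth", map_location=device))
+ generator.eval()
+ ```
+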
+ ## 🧪 Test
+
+ ```
+ CUDA_VISIBLE_DEVICES=0 python test_lut.py
+ ```
+
+ ## 🚀 Train
+
+ ```
+ CUDA_VISIBLE_DEVICES=0 python fine_tune_lut.py
+ ```
+
+ Before running, fill in the dataset, checkpoint, and output paths at the top of `main()` in `test_lut.py` and in the `__main__` section of `fine_tune_lut.py`; they are left as blank placeholders in the scripts.
+
+ ## 📖 Citation
+
+ If you find our work or dataset useful for your research, please cite our paper.
+
+ ```bibtex
+ @inproceedings{yi2025LUT-Fuse,
+   title={LUT-Fuse: Towards Extremely Fast Infrared and Visible Image Fusion via Distillation to Learnable Look-Up Tables},
+   author={Yi, Xunpeng and Zhang, Yibing and Xiang, Xinyu and Yan, Qinglong and Xu, Han and Ma, Jiayi},
+   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+   year={2025}
+ }
+ ```
+
+ If you have any questions, please send an email to zhangyibing@whu.edu.cn.
LUT-Fuse-main/assets/framework.png ADDED

Git LFS Details

  • SHA256: 7b4dfc0e7e134ba04d9c4b037ab76fba0a5a47b78a299c793baed28c3cb71d0c
  • Pointer size: 131 Bytes
  • Size of remote file: 646 kB
LUT-Fuse-main/ckpts/fine_tuned_lut.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24f70326cb33b27285e5157594d31a79ee94b67b043828cb98a7d60aeec920e4
+ size 262272
LUT-Fuse-main/ckpts/fine_tuned_lut_original.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ffa74433659b3a0d8e39ee1f2f4fda2d776ad2967a4e1d9190a67e314099a43
+ size 262272
LUT-Fuse-main/ckpts/generator_context.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fcacc0280f6ee37d42ae7b82166484f9fe1660526c647c560e051ff23885324
+ size 38143
LUT-Fuse-main/ckpts/generator_context_original.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7a4ec41712435defb95d3c9cfc23839cf79bc00ae22ba47cc5755d480525780
+ size 37559
LUT-Fuse-main/data/o_fusion_dataset.py ADDED
@@ -0,0 +1,64 @@
+ from PIL import Image
+ import torch
+ from torch.utils.data import Dataset
+ import os
+ from glob import glob
+ from torchvision.transforms import RandomCrop
+ import torchvision.transforms.functional as F
+
+ class RandomCropPair:
+     def __init__(self, size):
+         self.size = size  # crop size (h, w)
+
+     def __call__(self, vis_img, ir_img, fuse_image):
+         # sample random crop parameters once
+         i, j, h, w = RandomCrop.get_params(vis_img, output_size=self.size)
+         # apply the same crop to the visible, infrared and reference fusion images
+         vis_img = F.crop(vis_img, i, j, h, w)
+         ir_img = F.crop(ir_img, i, j, h, w)
+         fuse_image = F.crop(fuse_image, i, j, h, w)
+
+         vis_img = F.to_tensor(vis_img)
+         ir_img = F.to_tensor(ir_img)
+         fuse_image = F.to_tensor(fuse_image)
+         return vis_img, ir_img, fuse_image
+
+ class DistillDataSet(Dataset):
+     def __init__(self, visible_path, infrared_path, other_fuse_path, phase="train", transform=None):
+         self.phase = phase
+         self.visible_files = sorted(glob(os.path.join(visible_path, "*.*")))
+         self.infrared_files = sorted(glob(os.path.join(infrared_path, "*.*")))
+         self.other_fuse_files = sorted(glob(os.path.join(other_fuse_path, "*.*")))
+         self.transform = transform
+
+     def __len__(self):
+         l = len(self.infrared_files)
+         return l
+
+     def __getitem__(self, item):
+         image_A_path = self.visible_files[item]
+         image_B_path = self.infrared_files[item]
+         other_fuse_path = self.other_fuse_files[item]
+         image_A = Image.open(image_A_path).convert(mode='RGB')
+         image_B = Image.open(image_B_path).convert(mode='L')  # infrared is loaded as single-channel grayscale
+         other_fuse = Image.open(other_fuse_path).convert(mode='RGB')
+
+         if self.transform is not None:
+             if isinstance(self.transform, RandomCropPair):
+                 image_A, image_B, other_fuse = self.transform(image_A, image_B, other_fuse)
+             else:
+                 image_A = self.transform(image_A)
+                 image_B = self.transform(image_B)
+                 other_fuse = self.transform(other_fuse)
+
+         name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
+
+         return image_A, image_B, other_fuse, name
+
+     @staticmethod
+     def collate_fn(batch):
+         images_A, images_B, other_fuse, name = zip(*batch)
+         images_A = torch.stack(images_A, dim=0)
+         images_B = torch.stack(images_B, dim=0)
+         other_fuse = torch.stack(other_fuse, dim=0)
+         return images_A, images_B, other_fuse, name
LUT-Fuse-main/data/simple_dataset.py ADDED
@@ -0,0 +1,218 @@
1
+ from PIL import Image
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import os
5
+ from glob import glob
6
+ import transforms as T
7
+ from torchvision.transforms import RandomCrop
8
+ import torchvision.transforms.functional as F
9
+
10
+ # class SimpleDataSet(Dataset):
11
+ # def __init__(self, visible_path, visible_gt_path, infrared_path, phase="train", transform=None):
12
+ # self.phase = phase
13
+ # self.visible_files = sorted(glob(os.path.join(visible_path, "*")))
14
+ # self.visible_gt_files = sorted(glob(os.path.join(visible_gt_path, "*")))
15
+ # self.infrared_files = sorted(glob(os.path.join(infrared_path, "*")))
16
+ # self.transform = transform
17
+ #
18
+ # def __len__(self):
19
+ # l = len(self.infrared_files)
20
+ # return l
21
+ #
22
+ # def __getitem__(self, item):
23
+ # image_A_path = self.visible_files[item]
24
+ # image_A_gt_path = self.visible_gt_files[item]
25
+ # image_B_path = self.infrared_files[item]
26
+ # image_A = Image.open(image_A_path).convert(mode='RGB')
27
+ # image_A_gt = Image.open(image_A_gt_path).convert(mode='RGB')
28
+ # image_B = Image.open(image_B_path).convert(mode='L') ##########
29
+ #
30
+ # image_A = self.transform(image_A)
31
+ # image_A_gt = self.transform(image_A_gt)
32
+ # image_B = self.transform(image_B)
33
+ #
34
+ # name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
35
+ #
36
+ # return image_A, image_A_gt, image_B, name
37
+ #
38
+ # @staticmethod
39
+ # def collate_fn(batch):
40
+ # images_A, image_A_gt, images_B, name = zip(*batch)
41
+ # images_A = torch.stack(images_A, dim=0)
42
+ # image_A_gt = torch.stack(image_A_gt, dim=0)
43
+ # images_B = torch.stack(images_B, dim=0)
44
+ # return images_A, image_A_gt, images_B, name
45
+
46
+ # class RandomCropPair:
47
+ # def __init__(self, size):
48
+ # self.size = size # crop size (h, w)
+ #
+ # def __call__(self, vis_img, ir_img):
+ # # sample random crop parameters
+ # i, j, h, w = RandomCrop.get_params(vis_img, output_size=self.size)
+ # # apply the same crop to the visible and infrared images
54
+ # vis_img = F.crop(vis_img, i, j, h, w)
55
+ # ir_img = F.crop(ir_img, i, j, h, w)
56
+ #
57
+ # vis_img = F.to_tensor(vis_img)
58
+ # ir_img = F.to_tensor(ir_img)
59
+ # return vis_img, ir_img
60
+ #
61
+ # class SimpleDataSet(Dataset):
62
+ # def __init__(self, visible_path, infrared_path, phase="train", transform=None):
63
+ # self.phase = phase
64
+ # self.visible_files = sorted(glob(os.path.join(visible_path, "*.*")))
65
+ # self.infrared_files = sorted(glob(os.path.join(infrared_path, "*.*")))
66
+ # self.transform = transform
67
+ #
68
+ # def __len__(self):
69
+ # l = len(self.infrared_files)
70
+ # return l
71
+ #
72
+ # def __getitem__(self, item):
73
+ # image_A_path = self.visible_files[item]
74
+ # image_B_path = self.infrared_files[item]
75
+ # image_A = Image.open(image_A_path).convert(mode='RGB')
76
+ # image_B = Image.open(image_B_path).convert(mode='L') ##########
77
+ #
78
+ # # image_A = self.transform(image_A)
79
+ # # image_B = self.transform(image_B)
80
+ #
81
+ # if self.transform is not None:
82
+ # if isinstance(self.transform, RandomCropPair):
83
+ # image_A, image_B = self.transform(image_A, image_B)
84
+ # else:
85
+ # image_A = self.transform(image_A)
86
+ # image_B = self.transform(image_B)
87
+ #
88
+ # name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
89
+ #
90
+ # return image_A, image_B, name
91
+ #
92
+ # @staticmethod
93
+ # def collate_fn(batch):
94
+ # images_A, images_B, name = zip(*batch)
95
+ # images_A = torch.stack(images_A, dim=0)
96
+ # images_B = torch.stack(images_B, dim=0)
97
+ # return images_A, images_B, name
98
+ #
99
+ #
100
+ # class RandomCropPair:
101
+ # def __init__(self, size):
102
+ # self.size = size # crop size (h, w)
+ #
+ # def __call__(self, vis_blur_img, ir_blur_img, vis_gt_img, ir_gt_img):
+ # # sample random crop parameters
+ # i, j, h, w = RandomCrop.get_params(vis_blur_img, output_size=self.size)
+ # # apply the same crop to all four images
108
+ # vis_blur_img = F.crop(vis_blur_img, i, j, h, w)
109
+ # ir_blur_img = F.crop(ir_blur_img, i, j, h, w)
110
+ # vis_gt_img = F.crop(vis_gt_img, i, j, h, w)
111
+ # ir_gt_img = F.crop(ir_gt_img, i, j, h, w)
112
+ #
113
+ # vis_blur_img = F.to_tensor(vis_blur_img)
114
+ # ir_blur_img = F.to_tensor(ir_blur_img)
115
+ # vis_gt_img = F.to_tensor(vis_gt_img)
116
+ # ir_gt_img = F.to_tensor(ir_gt_img)
117
+ # return vis_blur_img, ir_blur_img, vis_gt_img, ir_gt_img
118
+ #
119
+ # class SimpleDataSet(Dataset):
120
+ # def __init__(self, visible_blur_path, infrared_blur_path, visible_gt_path, infrared_gt_path, phase="train", transform=None):
121
+ # self.phase = phase
122
+ # self.visible_blur_files = sorted(glob(os.path.join(visible_blur_path, "*.*")))
123
+ # self.infrared_blur_files = sorted(glob(os.path.join(infrared_blur_path, "*.*")))
124
+ # self.visible_gt_files = sorted(glob(os.path.join(visible_gt_path, "*.*")))
125
+ # self.infrared_gt_files = sorted(glob(os.path.join(infrared_gt_path, "*.*")))
126
+ # self.transform = transform
127
+ #
128
+ # def __len__(self):
129
+ # l = len(self.infrared_gt_files)
130
+ # return l
131
+ #
132
+ # def __getitem__(self, item):
133
+ # image_A_blur_path = self.visible_blur_files[item]
134
+ # image_B_blur_path = self.infrared_blur_files[item]
135
+ # image_A_gt_path = self.visible_gt_files[item]
136
+ # image_B_gt_path = self.infrared_gt_files[item]
137
+ # image_A_blur = Image.open(image_A_blur_path).convert(mode='RGB')
138
+ # image_B_blur = Image.open(image_B_blur_path).convert(mode='L') ##########
139
+ # image_A_gt = Image.open(image_A_gt_path).convert(mode='RGB')
140
+ # image_B_gt = Image.open(image_B_gt_path).convert(mode='L') ##########
141
+ #
142
+ # if self.transform is not None:
143
+ # if isinstance(self.transform, RandomCropPair):
144
+ # image_A_blur, image_B_blur, image_A_gt, image_B_gt = self.transform(image_A_blur, image_B_blur, image_A_gt, image_B_gt)
145
+ # else:
146
+ # image_A_blur = self.transform(image_A_blur)
147
+ # image_B_blur = self.transform(image_B_blur)
148
+ # image_A_gt = self.transform(image_A_gt)
149
+ # image_B_gt = self.transform(image_B_gt)
150
+ #
151
+ # name = image_A_blur_path.replace("\\", "/").split("/")[-1].split(".")[0]
152
+ #
153
+ # return image_A_blur, image_B_blur, image_A_gt, image_B_gt, name
154
+ #
155
+ # @staticmethod
156
+ # def collate_fn(batch):
157
+ # image_A_blur, image_B_blur, image_A_gt, image_B_gt, name = zip(*batch)
158
+ # image_A_blur = torch.stack(image_A_blur, dim=0)
159
+ # image_B_blur = torch.stack(image_B_blur, dim=0)
160
+ # image_A_gt = torch.stack(image_A_gt, dim=0)
161
+ # image_B_gt = torch.stack(image_B_gt, dim=0)
162
+ # return image_A_blur, image_B_blur, image_A_gt, image_B_gt, name
163
+
164
+
165
+ class RandomCropPair:
166
+ def __init__(self, size):
167
+ self.size = size  # crop size (h, w)
+
+ def __call__(self, vis_img, ir_img):
+ # sample random crop parameters
+ i, j, h, w = RandomCrop.get_params(vis_img, output_size=self.size)
+ # apply the same crop to the visible and infrared images
173
+ vis_img = F.crop(vis_img, i, j, h, w)
174
+ ir_img = F.crop(ir_img, i, j, h, w)
175
+
176
+ vis_img = F.to_tensor(vis_img)
177
+ ir_img = F.to_tensor(ir_img)
178
+ return vis_img, ir_img
179
+
180
+ class SimpleDataSet(Dataset):
181
+ def __init__(self, visible_path, infrared_path, phase="train", transform=None):
182
+ self.phase = phase
183
+ self.visible_files = sorted(glob(os.path.join(visible_path, "*.*")))
184
+ self.infrared_files = sorted(glob(os.path.join(infrared_path, "*.*")))
185
+ self.transform = transform
186
+
187
+ def __len__(self):
188
+ l = len(self.infrared_files)
189
+ return l
190
+
191
+ def __getitem__(self, item):
192
+ image_A_path = self.visible_files[item]
193
+ image_B_path = self.infrared_files[item]
194
+ image_A = Image.open(image_A_path).convert(mode='RGB')
195
+ image_B = Image.open(image_B_path).convert(mode='L')  # infrared is loaded as single-channel grayscale
196
+
197
+ # image_A = self.transform(image_A)
198
+ # image_B = self.transform(image_B)
199
+
200
+ if self.transform is not None:
201
+ if isinstance(self.transform, RandomCropPair):
202
+ image_A, image_B = self.transform(image_A, image_B)
203
+ else:
204
+ image_A = self.transform(image_A)
205
+ image_B = self.transform(image_B)
206
+
207
+ name = image_A_path.replace("\\", "/").split("/")[-1].split(".")[0]
208
+
209
+ return image_A, image_B, name
210
+
211
+ @staticmethod
212
+ def collate_fn(batch):
213
+ images_A, images_B, name = zip(*batch)
214
+ images_A = torch.stack(images_A, dim=0)
215
+ images_B = torch.stack(images_B, dim=0)
216
+ return images_A, images_B, name
217
+
218
+
LUT-Fuse-main/fine_tune_lut.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.optim as optim
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ from torch.utils.tensorboard import SummaryWriter
7
+
8
+ from data.o_fusion_dataset import DistillDataSet
9
+ from data.o_fusion_dataset import RandomCropPair
10
+ import datetime
11
+ import os
12
+
13
+ import transforms as T
14
+ from scripts.loss_lut import fusion_loss
15
+ from itertools import chain
16
+ from scripts.calculate import OptimizableLUT, Generator_for_info, apply_fusion_4d_with_interpolation
17
+
18
+ cuda = True if torch.cuda.is_available() else False
19
+ Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
20
+
21
+
22
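+ # Regularizers over the 4D LUT: "tv" is a smoothness (total-variation) penalty on squared
+ # differences between neighbouring LUT entries along each of the four input dimensions,
+ # and "mn" penalizes entries that decrease as an input coordinate increases (monotonicity).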
+ class TV_4D(nn.Module):
23
+ def __init__(self, dim=16, output_channels=3):
24
+ super(TV_4D, self).__init__()
25
+
26
+ self.weight_r = torch.ones(dim, dim, dim, dim - 1, output_channels, dtype=torch.float)
27
+ self.weight_r[..., (0, dim - 2), :] *= 2.0
28
+
29
+ self.weight_g = torch.ones(dim, dim, dim - 1, dim, output_channels, dtype=torch.float)
30
+ self.weight_g[..., (0, dim - 2), :, :] *= 2.0
31
+
32
+ self.weight_b = torch.ones(dim, dim - 1, dim, dim, output_channels, dtype=torch.float)
33
+ self.weight_b[..., (0, dim - 2), :, :, :] *= 2.0
34
+
35
+ self.weight_ir = torch.ones(dim - 1, dim, dim, dim, output_channels, dtype=torch.float)
36
+ self.weight_ir[(0, dim - 2), :, :, :, :] *= 2.0
37
+
38
+ self.relu = torch.nn.ReLU()
39
+
40
+ def forward(self, LUT):
41
+ device = LUT.device
42
+
43
+ self.weight_r = self.weight_r.to(device)
44
+
45
+ self.weight_g = self.weight_g.to(device)
46
+ self.weight_b = self.weight_b.to(device)
47
+ self.weight_ir = self.weight_ir.to(device)
48
+
49
+ dif_r = LUT[ :, :, :, :-1, :] - LUT[ :, :, :, 1:, :]
50
+ dif_g = LUT[ :, :, :-1, :, :] - LUT[ :, :, 1:, :, :]
51
+ dif_b = LUT[ :, :-1, :, :, :] - LUT[ :, 1:, :, :, :]
52
+ dif_ir = LUT[ :-1, :, :, :, :] - LUT[ 1:, :, :, :, :]
53
+
54
+ tv = (torch.mean(torch.mul(dif_r ** 2, self.weight_r)) + torch.mean(torch.mul(dif_g ** 2, self.weight_g)) +
55
+ torch.mean(torch.mul(dif_b ** 2, self.weight_b)) + torch.mean(torch.mul(dif_ir ** 2, self.weight_ir)))
56
+
57
+ mn = (torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) +
58
+ torch.mean(self.relu(dif_b)) + torch.mean(self.relu(dif_ir)))
59
+
60
+ return tv, mn
61
+
62
+
63
+ def fine_tune_lut(lut_model, Generator_context, train_loader, val_loader, device, epochs, learning_rate, save_dir="ww"):
64
+ TV4 = TV_4D().to(device)
65
+ best_val_loss = 1e5
66
+ Generator_context.train()
67
+ loss_fuction = fusion_loss()
68
+ optimizer = optim.Adam(chain(lut_model.parameters(), Generator_context.parameters()), lr=learning_rate)
69
+ # optimizer = optim.Adam(lut_model.parameters(), lr=learning_rate)
70
+
71
+ for epoch in range(epochs):
72
+ lut_model.train()
73
+
74
+ train_loss = 0
75
+ # train_loss_max = 0
76
+ # train_loss_text = 0
77
+ train_loss_l1 = 0
78
+ train_loss_ssim = 0
79
+ train_loss_tv0 = 0
80
+ train_loss_mn0 = 0
81
+
82
+ for step, data in enumerate(train_loader):
83
+ I_A, I_B, fuse, _ = data
84
+ optimizer.zero_grad()  # reset gradients at each step
85
+
86
+ if torch.cuda.is_available():
87
+ I_A = I_A.to(device)
88
+ I_B = I_B.to(device)
89
+ high_quality = fuse.to(device)
90
+ loss_fuction = loss_fuction.to(device)
91
+
92
+ lut = lut_model()
93
+
94
+ tv0, mn0 = TV4(lut)
95
+ loss_tv0 = tv0
96
+ loss_mn0 = mn0
97
+
98
+ outputs = apply_fusion_4d_with_interpolation(I_A * 255., I_B * 255., lut, Generator_context)
99
+
100
+ l1 = F.l1_loss(outputs, high_quality)
101
+ ssim = loss_fuction(I_A, I_B, outputs)
102
+ loss_all = l1 + ssim + 10.0 * loss_mn0 + 0.0001 * loss_tv0 #+ text_loss + loss_max
103
+
104
+ loss_all.backward()
105
+ optimizer.step()
106
+
107
+ train_loss += loss_all.item()
108
+ train_loss_l1 += l1.item()
109
+ train_loss_ssim += ssim.item()
110
+ # train_loss_text += text_loss.item()
111
+ # train_loss_max += loss_max.item()
112
+ train_loss_tv0 += loss_tv0.item()
113
+ train_loss_mn0 += loss_mn0.item()
114
+ # train_loss_color += loss_color.item()
115
+
116
+ tb_writer.add_scalar("train_total_loss", train_loss/len(train_loader), epoch)
117
+ tb_writer.add_scalar("train_loss_l1", train_loss_l1/len(train_loader), epoch)
118
+ tb_writer.add_scalar("train_loss_ssim", train_loss_ssim / len(train_loader), epoch)
119
+ # tb_writer.add_scalar("train_loss_text", train_loss_text / len(train_loader), epoch)
120
+ # tb_writer.add_scalar("train_loss_max", train_loss_max / len(train_loader), epoch)
121
+ tb_writer.add_scalar("train_loss_tv0", train_loss_tv0/len(train_loader), epoch)
122
+ tb_writer.add_scalar("train_loss_mn0", train_loss_mn0/len(train_loader), epoch)
123
+
124
+ print(f"Epoch {epoch + 1}/{epochs} - Loss: {train_loss / len(train_loader):.6f} - loss_l1: {train_loss_l1 / len(train_loader):.6f} - loss_ssim: {train_loss_ssim / len(train_loader):.6f} - loss_tv: {train_loss_tv0 / len(train_loader):.6f} - loss_mn: {train_loss_mn0 / len(train_loader):.6f}")
125
+ # print(f"Epoch {epoch + 1}/{epochs} - Loss: {train_loss / len(train_loader):.6f} - l1: {train_loss_l1 / len(train_loader):.6f} - loss_text: {train_loss_text / len(train_loader):.6f} - loss_max: {train_loss_max / len(train_loader):.6f}")
126
+ # print(f"Epoch {epoch + 1}/{epochs} - Loss: {train_loss / len(train_loader):.6f} - l1: {train_loss_l1 / len(train_loader):.6f} - tv: {train_loss_tv0 / len(train_loader):.6f} - mn: {train_loss_mn0 / len(train_loader):.6f}")
127
+
128
+ if (epoch + 1) % 10 == 0:
129
+ val_loss, val_loss_l1, val_loss_ssim, val_loss_tv0, val_loss_mn0 = validate_lut(lut_model, val_loader, device)
130
+ tb_writer.add_scalar("val_total_loss", val_loss/len(val_loader), epoch)
131
+ tb_writer.add_scalar("val_loss_l1", val_loss_l1/len(val_loader), epoch)
132
+ tb_writer.add_scalar("val_loss_ssim", val_loss_ssim / len(val_loader), epoch)
133
+ # tb_writer.add_scalar("val_loss_text", val_loss_text / len(val_loader), epoch)
134
+ # tb_writer.add_scalar("val_loss_max", val_loss_max / len(val_loader), epoch)
135
+ tb_writer.add_scalar("val_loss_tv0", val_loss_tv0/len(val_loader), epoch)
136
+ tb_writer.add_scalar("val_loss_mn0", val_loss_mn0/len(val_loader), epoch)
137
+ # print(f"Validation - Epoch {epoch} - Loss: {val_loss / len(val_loader):.6f} - l1: {val_loss_l1 / len(val_loader):.6f} - tv: {val_loss_tv0 / len(val_loader):.6f} - mn: {val_loss_mn0 / len(val_loader):.6f}")
138
+
139
+ if val_loss < best_val_loss :
140
+ best_val_loss = val_loss
141
+ filename = f"fine_tuned_ygcy_epoch{epoch}_valloss{val_loss:.6f}.npy"
142
+ full_path = os.path.join(filefold_path, filename)
143
+ save_lut(lut_model, full_path)
144
+
145
+ context_filename = f"generator_context_epoch{epoch}_valloss{val_loss:.6f}.pth"
146
+ generator_context_save_path = os.path.join(filefold_path, context_filename)
147
+ save_generator_context(Generator_context, save_path=generator_context_save_path)
148
+
149
+ print(f"Validation - Epoch {epoch} - Loss: {val_loss / len(val_loader):.6f} - l1: {val_loss_l1 / len(val_loader):.6f} - loss_ssim: {val_loss_ssim / len(val_loader):.6f} - loss_tv0: {val_loss_tv0 / len(val_loader):.6f} - loss_mn0: {val_loss_mn0 / len(val_loader):.6f}")
150
+
151
+
152
+ def save_lut(lut_module, path):
153
+
154
+ lut_weights = lut_module().detach().cpu().numpy()
155
+ np.save(path, lut_weights)
156
+ print(f"Fine-tuned LUT saved to {path}")
157
+
158
+
159
+ def validate_lut(lut_module, val_loader, device):
160
+ train_loss = 0
161
+ train_loss_mn0 = 0
162
+ train_loss_tv0 = 0
163
+ train_loss_ssim = 0
164
+ train_loss_l1 = 0
165
+ # train_loss_text = 0
166
+ TV4 = TV_4D().to(device)
167
+
168
+ loss_fuction = fusion_loss()
169
+ Generator_context.eval()
170
+
171
+ lut = lut_module()
172
+ with torch.no_grad():
173
+ for step, data in enumerate(val_loader):
174
+ I_A, I_B, fuse, task = data
175
+ if torch.cuda.is_available():
176
+ I_A = I_A.to(device)
177
+ I_B = I_B.to(device)
178
+ high_quality = fuse.to(device)
179
+ loss_fuction = loss_fuction.to(device)
180
+
181
+ outputs = apply_fusion_4d_with_interpolation(I_A * 255., I_B * 255., lut, Generator_context)
182
+ tv0, mn0 = TV4(lut)
183
+ loss_tv0 = tv0
184
+ loss_mn0 = mn0
185
+ l1 = F.l1_loss(outputs, high_quality)
186
+ loss_ssim = loss_fuction(I_A, I_B, outputs)
187
+ loss_all = l1 + loss_ssim + 0.1 * loss_mn0 + 10.0 * loss_tv0 #+ text_loss + max_loss
188
+
189
+ train_loss += loss_all.item()
190
+ train_loss_l1 += l1.item()
191
+ train_loss_ssim += loss_ssim.item()
192
+ # train_loss_text += text_loss.item()
193
+ # train_loss_max += max_loss.item()
194
+ train_loss_tv0 += loss_tv0.item()
195
+ train_loss_mn0 += loss_mn0.item()
196
+
197
+ return train_loss, train_loss_l1, train_loss_ssim, train_loss_tv0, train_loss_mn0
198
+
199
+
200
+ def save_generator_context(generator_context, save_path="generator_context.pth"):
201
+ torch.save(generator_context.state_dict(), save_path)
202
+ print(f"Generator_for_info weights saved to {save_path}")
203
+
204
+ if __name__ == "__main__":
205
+
206
+ if os.path.exists("./finetune_lut_exp") is False:
207
+ os.makedirs("./finetune_lut_exp")
208
+ file_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
209
+ filefold_path = "./finetune_lut_exp/finetune_lut_{}".format(file_name)
210
+ file_log_path = os.path.join(filefold_path, "log")
211
+ os.makedirs(file_log_path)
212
+
213
+ tb_writer = SummaryWriter(log_dir=file_log_path)
214
+
215
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
216
+
217
+ lut_filepath = "ckpts/fine_tuned_lut_original.npy"
218
+ lut_tensor = torch.tensor(np.load(lut_filepath).astype(np.float32), device=DEVICE)
219
+ lut = OptimizableLUT(lut_tensor)
220
+
221
+ context_file = "ckpts/generator_context_original.pth"
222
+ Generator_context = Generator_for_info().to(DEVICE)
223
+ Generator_context.load_state_dict(torch.load(context_file))
224
+ # Generator_context.eval()
225
+
226
+ batch_size = 6
227
+ visible_path = " "
228
+ infrared_path = " "
229
+ train_fusion_path = " "
230
+ test_visible_path = " "
231
+ test_infrared_path = " "
232
+ test_fusion_path = " "
233
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
234
+ data_transform = {
235
+ "train": RandomCropPair(size=(128, 128)),
236
+ "val": T.Compose([T.Resize_16(),
237
+ T.ToTensor()])}
238
+
239
+ train_dataset = DistillDataSet(visible_path=visible_path,
240
+ infrared_path=infrared_path,
241
+ other_fuse_path=train_fusion_path,
242
+ phase="train",
243
+ transform=data_transform["train"])
244
+ train_loader = torch.utils.data.DataLoader(train_dataset,
245
+ batch_size=batch_size,
246
+ shuffle=True,
247
+ pin_memory=True,
248
+ num_workers=nw,
249
+ collate_fn=train_dataset.collate_fn)
250
+
251
+ val_dataset = DistillDataSet(visible_path=test_visible_path,
252
+ infrared_path=test_infrared_path,
253
+ other_fuse_path=test_fusion_path,
254
+ phase="val",
255
+ transform=data_transform["val"])
256
+ val_loader = torch.utils.data.DataLoader(val_dataset,
257
+ batch_size=1,
258
+ shuffle=False,
259
+ pin_memory=True,
260
+ num_workers=nw,
261
+ collate_fn=val_dataset.collate_fn)
262
+
263
+ fine_tune_lut(lut, Generator_context, train_loader, val_loader, DEVICE, epochs=496, learning_rate=5e-5)
264
+
LUT-Fuse-main/requirements.txt ADDED
@@ -0,0 +1,104 @@
1
+ absl-py==2.1.0
2
+ asttokens==2.4.1
3
+ astunparse==1.6.3
4
+ backcall==0.2.0
5
+ Brotli
6
+ cachetools==5.5.0
7
+ certifi
8
+ charset-normalizer
9
+ contourpy==1.1.1
10
+ cycler==0.12.1
11
+ decorator==5.1.1
12
+ einops==0.8.0
13
+ et_xmlfile==2.0.0
14
+ executing==2.1.0
15
+ filelock
16
+ flatbuffers==24.3.25
17
+ fonttools==4.54.1
18
+ fsspec==2024.12.0
19
+ ftfy==6.2.3
20
+ gast==0.4.0
21
+ gmpy2
22
+ google-auth==2.35.0
23
+ google-auth-oauthlib==1.0.0
24
+ google-pasta==0.2.0
25
+ grpcio==1.66.2
26
+ h5py==3.11.0
27
+ huggingface-hub==0.27.0
28
+ idna
29
+ imageio==2.35.1
30
+ importlib_metadata==8.5.0
31
+ importlib_resources==6.4.5
32
+ ipython==8.12.3
33
+ jedi==0.19.1
34
+ Jinja2
35
+ keras==2.13.1
36
+ kiwisolver==1.4.7
37
+ lazy_loader==0.4
38
+ libclang==18.1.1
39
+ Markdown==3.7
40
+ MarkupSafe
41
+ matplotlib==3.7.5
42
+ matplotlib-inline==0.1.7
43
+ mkl-fft
44
+ mkl-random
45
+ mkl-service==2.4.0
46
+ mpmath
47
+ networkx
48
+ numpy
49
+ nvidia-ml-py==12.535.161
50
+ nvitop==1.3.2
51
+ oauthlib==3.2.2
52
+ opencv-python==4.9.0.80
53
+ openpyxl==3.1.5
54
+ opt_einsum==3.4.0
55
+ packaging==24.1
56
+ pandas==2.0.3
57
+ parso==0.8.4
58
+ pexpect==4.9.0
59
+ pickleshare==0.7.5
60
+ pillow
61
+ prompt_toolkit==3.0.48
62
+ protobuf==4.25.5
63
+ psutil==6.1.0
64
+ ptyprocess==0.7.0
65
+ pure_eval==0.2.3
66
+ pyasn1==0.6.1
67
+ pyasn1_modules==0.4.1
68
+ Pygments==2.18.0
69
+ pyparsing==3.1.4
70
+ PySocks
71
+ python-dateutil==2.9.0.post0
72
+ pytorch-msssim==1.0.0
73
+ pytz==2024.2
74
+ PyWavelets==1.4.1
75
+ PyYAML==6.0.2
76
+ regex==2024.9.11
77
+ requests
78
+ requests-oauthlib==2.0.0
79
+ rsa==4.9
80
+ safetensors==0.4.5
81
+ scikit-image==0.21.0
82
+ scipy==1.10.1
83
+ seaborn==0.13.2
84
+ six==1.16.0
85
+ stack-data==0.6.3
86
+ sympy
87
+ tensorboard==2.13.0
88
+ tensorboard-data-server==0.7.2
89
+ termcolor==2.4.0
90
+ tifffile==2023.7.10
91
+ tokenizers==0.20.3
92
+ torch==2.0.0
93
+ torchvision==0.15.0
94
+ tqdm==4.66.5
95
+ traitlets==5.14.3
96
+ transformers==4.46.3
97
+ triton==2.0.0
98
+ typing_extensions==4.5.0
99
+ tzdata==2024.2
100
+ urllib3
101
+ wcwidth==0.2.13
102
+ Werkzeug==3.0.4
103
+ wrapt==1.16.0
104
+ zipp==3.20.2
LUT-Fuse-main/scripts/calculate.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+ from torchvision.transforms import ToPILImage
6
+ import time
7
+ from data.simple_dataset import RandomCropPair
8
+ from data.simple_dataset import SimpleDataSet
9
+ import torch.nn as nn
10
+ import transforms as T
11
+ import os
12
+
13
+
14
+ def rgb_to_ycbcr(img):
15
+ return torch.stack(
16
+ (0. / 256. + img[:, 0, :, :] * 0.299000 + img[:, 1, :, :] * 0.587000 + img[:, 2, :, :] * 0.114000,
17
+ 128. / 256. - img[:, 0, :, :] * 0.168736 - img[:, 1, :, :] * 0.331264 + img[:, 2, :, :] * 0.500000,
18
+ 128. / 256. + img[:, 0, :, :] * 0.500000 - img[:, 1, :, :] * 0.418688 - img[:, 2, :, :] * 0.081312),
19
+ dim=1)
20
+
21
+
22
+ def ycbcr_to_rgb(img):
23
+ return torch.stack(
24
+ (img[:, 0, :, :] + (img[:, 2, :, :] - 0.5) * 1.402,
25
+ img[:, 0, :, :] - (img[:, 1, :, :] - 0.5) * 0.344136 - (img[:, 2, :, :] - 0.5) * 0.714136,
26
+ img[:, 0, :, :] + (img[:, 1, :, :] - 0.5) * 1.772),
27
+ dim=1)
28
+
29
+
30
+ def load_lookup_table(filepath):
31
+ try:
32
+ lut = np.load(filepath).astype(np.float32)
33
+ lut = torch.tensor(lut, device="cuda")  # move the look-up table onto the GPU
+ return lut
+ except Exception as e:
+ print(f"Error loading the look-up table: {e}")
37
+ return None
38
+
39
+
40
+ def generator_block(in_filters, out_filters, normalization=False):
41
+ """Returns downsampling layers of each discriminator block"""
42
+ layers = [nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1)]
43
+ layers.append(nn.LeakyReLU(0.2))
44
+ if normalization:
45
+ layers.append(nn.InstanceNorm2d(out_filters, affine=True))
46
+ # layers.append(nn.BatchNorm2d(out_filters))
47
+ return layers
48
+
49
+
50
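+ # Lightweight CNN that maps the concatenated visible (RGB) and infrared (single-channel) input
+ # to a one-channel context map in [0, 1]; this map provides the context coordinate of the 4D LUT.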
+ class Generator_for_info(nn.Module):
51
+ def __init__(self, in_channels=4):
52
+ super(Generator_for_info, self).__init__()
53
+
54
+ self.input_layer = nn.Sequential(
55
+ nn.Conv2d(in_channels, 16, 3, stride=1, padding=1),
56
+ nn.LeakyReLU(0.2),
57
+ nn.InstanceNorm2d(16, affine=True),)
58
+
59
+ self.mid_layer = nn.Sequential(
60
+ *generator_block(16, 16, normalization=True),
61
+ *generator_block(16, 16, normalization=True),
62
+ *generator_block(16, 16, normalization=True),)
63
+
64
+ self.output_layer = nn.Sequential(
65
+ nn.Dropout(p=0.5),
66
+ nn.Conv2d(16, 1, 3, stride=1, padding=1),
67
+ nn.Sigmoid())
68
+
69
+ def forward(self, img_input):
70
+ x = self.input_layer(img_input)
71
+ identity = x
72
+ out = self.mid_layer(x)
73
+ out += identity
74
+ out = self.output_layer(out)
75
+ return out
76
+
77
+
78
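+ # Quadrilinear 4D LUT lookup: the visible Y channel, its Sobel gradient magnitude, the learned
+ # context map and the infrared intensity are scaled to the LUT grid, the fused Y value is
+ # interpolated from the 16 surrounding LUT entries, and the result is recombined with the
+ # visible Cb/Cr channels before converting back to RGB.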
+ def apply_fusion_4d_with_interpolation(visible_img, infrared_img, lut, get_context):
79
+
80
+ image_cat = torch.cat((visible_img, infrared_img), dim=1) # [0, 255]
81
+ context = get_context(image_cat)
82
+
83
+ context_scaled = (context *255./ 16.0).squeeze(1) # [0, 16]
84
+ infrared_scaled = infrared_img / 16.0 # [0, 16]
85
+
86
+ ycbcr_vis = rgb_to_ycbcr(visible_img / 255.) # [0, 1]
87
+ ycbcr_vis_scaled = ycbcr_vis * 255.0 / 16.0 # [0, 16]
88
+
89
+ y_vi_scaled = ycbcr_vis_scaled[:, 0, :, :] # [b, 1, h, w] # [0, 16]
90
+ cb_cr = ycbcr_vis[:, 1:, :, :] # [0, 1]
91
+ ir_scaled = infrared_scaled[:, 0, :, :] # [0, 16]
92
+
93
+ # get floor and ceil indices
94
+ ir_floor = torch.floor(ir_scaled).long()
95
+ ir_ceil = torch.clamp(ir_floor + 1, 0, lut.shape[3] - 1)
96
+ ir_alpha = ir_scaled - ir_floor
97
+
98
+ y_vi_floor = torch.floor(y_vi_scaled).long()
99
+ y_vi_ceil = torch.clamp(y_vi_floor + 1, 0, lut.shape[0] - 1)
100
+ y_vi_alpha = y_vi_scaled - y_vi_floor
101
+
102
+ c_floor = torch.floor(context_scaled).long()
103
+ c_ceil = torch.clamp(c_floor + 1, 0, lut.shape[2] - 1)
104
+ c_alpha = context_scaled - c_floor
105
+
106
+ sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(
107
+ visible_img.device)
108
+ sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(
109
+ visible_img.device)
110
+
111
+ # compute gradients in the x and y directions
+ grad_x = torch.nn.functional.conv2d(ycbcr_vis[:, :1, :, :], sobel_x, padding=1)
+ grad_y = torch.nn.functional.conv2d(ycbcr_vis[:, :1, :, :], sobel_y, padding=1)
+
+ # compute the gradient magnitude
116
+ gradient = torch.sqrt(grad_x ** 2 + grad_y ** 2) # [b, 1, h, w]
117
+ min_val = gradient.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values
118
+ max_val = gradient.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values
119
+
120
+ gradient_normalized = (gradient - min_val) / (max_val - min_val + 1e-8)
121
+
122
+ gradient_scaled = (gradient_normalized * 255.)
123
+ gradient_scaled = (gradient_scaled / 16.0).squeeze(1)
124
+
125
+ g_floor = torch.floor(gradient_scaled).long()
126
+ g_ceil = torch.clamp(g_floor + 1, 0, lut.shape[1] - 1)
127
+ g_alpha = gradient_scaled - g_floor
128
+
129
+ ir_alpha = ir_alpha.unsqueeze(-1)
130
+ y_vi_alpha = y_vi_alpha.unsqueeze(-1)
131
+ g_alpha = g_alpha.unsqueeze(-1)
132
+ c_alpha = c_alpha.unsqueeze(-1)
133
+
134
+ def lerp(v1, v2, alpha):
135
+ out = v1 * (1 - alpha) + v2 * alpha
136
+ return out
137
+
138
+ fusion_result = (
139
+ lerp(
140
+ lerp(
141
+ lerp(
142
+ lerp(lut[y_vi_floor, g_floor, c_floor, ir_floor], lut[y_vi_floor, g_floor, c_floor, ir_ceil],
143
+ ir_alpha),
144
+ lerp(lut[y_vi_floor, g_floor, c_ceil, ir_floor], lut[y_vi_floor, g_floor, c_ceil, ir_ceil],
145
+ ir_alpha),
146
+ c_alpha,
147
+ ),
148
+ lerp(
149
+ lerp(lut[y_vi_floor, g_ceil, c_floor, ir_floor], lut[y_vi_floor, g_ceil, c_floor, ir_ceil],
150
+ ir_alpha),
151
+ lerp(lut[y_vi_floor, g_ceil, c_ceil, ir_floor], lut[y_vi_floor, g_ceil, c_ceil, ir_ceil], ir_alpha),
152
+ c_alpha,
153
+ ),
154
+ g_alpha,
155
+ ),
156
+ lerp(
157
+ lerp(
158
+ lerp(lut[y_vi_ceil, g_floor, c_floor, ir_floor], lut[y_vi_ceil, g_floor, c_floor, ir_ceil],
159
+ ir_alpha),
160
+ lerp(lut[y_vi_ceil, g_floor, c_ceil, ir_floor], lut[y_vi_ceil, g_floor, c_ceil, ir_ceil], ir_alpha),
161
+ c_alpha,
162
+ ),
163
+ lerp(
164
+ lerp(lut[y_vi_ceil, g_ceil, c_floor, ir_floor], lut[y_vi_ceil, g_ceil, c_floor, ir_ceil], ir_alpha),
165
+ lerp(lut[y_vi_ceil, g_ceil, c_ceil, ir_floor], lut[y_vi_ceil, g_ceil, c_ceil, ir_ceil], ir_alpha),
166
+ c_alpha,
167
+ ),
168
+ g_alpha,
169
+ ),
170
+ y_vi_alpha,
171
+ )
172
+ )
173
+
174
+ fusion_y = fusion_result.permute(0, 3, 1, 2)
175
+
176
+ fusion_ycbcr = torch.cat([fusion_y, cb_cr], dim=1)
177
+ fusion_rgb = ycbcr_to_rgb(fusion_ycbcr) # fusion_rgb = fusion_ycbcr.permute(0, 3, 1, 2)
178
+
179
+ return fusion_rgb
180
+
LUT-Fuse-main/scripts/loss_lut.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from math import exp
5
+
6
+
7
+ class fusion_loss(nn.Module):
8
+ def __init__(self):
9
+ super(fusion_loss, self).__init__()
10
+ self.loss_func_ssim = L_SSIM(window_size=48)
11
+ self.loss_func_Grad = GradientMaxLoss()
12
+ self.loss_func_l1 = nn.L1Loss()  # L1 loss
+ self.loss_func_l2 = nn.MSELoss()  # L2 loss
14
+ self.loss_func_color = L_color()
15
+ self.loss_func_Max = L_Intensity_Max_RGB()
16
+
17
+ def forward(self, image_vi, image_ir, image_fused, max_ratio=4, consist_ratio=1, ssim_ir_ratio=1,
18
+ ssim_ratio=1, ir_compose=1, color_ratio=12, text_ratio=2, max_mode="l1", consist_mode="l1"):
19
+ image_visible_gray = self.rgb2gray(image_vi)
20
+ image_infrared_gray = self.rgb2gray(image_ir)
21
+ image_fused_gray = self.rgb2gray(image_fused)
22
+ # loss_text = text_ratio * self.loss_func_Grad(image_visible_gray, image_infrared_gray, image_fused_gray)
23
+ # loss_max = max_ratio * self.loss_func_Max(image_vi, image_ir, image_fused, max_mode)
24
+ loss_ssim = ssim_ratio * (self.loss_func_ssim(image_vi, image_fused) + ssim_ir_ratio * self.loss_func_ssim(image_ir, image_fused_gray))
25
+ return loss_ssim
26
+
27
+ def rgb2gray(self, image):
28
+ b, c, h, w = image.size()
29
+ if c == 1:
30
+ return image
31
+ image_gray = 0.299 * image[:, 0, :, :] + 0.587 * image[:, 1, :, :] + 0.114 * image[:, 2, :, :]
32
+ image_gray = image_gray.unsqueeze(dim=1)
33
+ return image_gray
34
+
35
+ def gaussian(window_size, sigma):
36
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
37
+ return gauss / gauss.sum()
38
+
39
+
40
+ class L_Intensity_Max_RGB(nn.Module):
41
+ def __init__(self):
42
+ super(L_Intensity_Max_RGB, self).__init__()
43
+
44
+ def forward(self, image_visible, image_infrared, image_fused, max_mode="l1"):
45
+ gray_visible = torch.mean(image_visible, dim=1, keepdim=True)
46
+ gray_infrared = torch.mean(image_infrared, dim=1, keepdim=True)
47
+
48
+ mask = (gray_infrared > gray_visible).float()
49
+
50
+ fused_image = mask * image_infrared + (1 - mask) * image_visible
51
+ if max_mode == "l1":
52
+ Loss_intensity = F.l1_loss(fused_image, image_fused)
53
+ else:
54
+ Loss_intensity = F.mse_loss(fused_image, image_fused)
55
+ return Loss_intensity
56
+
57
+
58
+ def create_window(window_size, channel=1):
59
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
60
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
61
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
62
+ return window
63
+
64
+
65
+ def ssim(img1, img2, window_size=24, window=None, size_average=True, val_range=None):
66
+
67
+ if val_range is None:
68
+ if torch.max(img1) > 128:
69
+ max_val = 255
70
+ else:
71
+ max_val = 1
72
+
73
+ if torch.min(img1) < -0.5:
74
+ min_val = -1
75
+ else:
76
+ min_val = 0
77
+ L = max_val - min_val
78
+ else:
79
+ L = val_range
80
+
81
+ padd = 0
82
+ (_, channel, height, width) = img1.size()
83
+ if window is None:
84
+ real_size = min(window_size, height, width)
85
+ window = create_window(real_size, channel=channel).to(img1.device)
86
+
87
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
88
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
89
+
90
+ mu1_sq = mu1.pow(2)
91
+ mu2_sq = mu2.pow(2)
92
+ mu1_mu2 = mu1 * mu2
93
+
94
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
95
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
96
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
97
+
98
+ C1 = (0.01 * L) ** 2
99
+ C2 = (0.03 * L) ** 2
100
+
101
+ v1 = 2.0 * sigma12 + C2
102
+ v2 = sigma1_sq + sigma2_sq + C2
103
+ cs = torch.mean(v1 / v2) # contrast sensitivity
104
+
105
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
106
+
107
+ if size_average:
108
+ ret = ssim_map.mean()
109
+ else:
110
+ ret = ssim_map.mean(1).mean(1).mean(1)
111
+
112
+ return 1 - ret
113
+
114
+ class GradientMaxLoss(nn.Module):
115
+ def __init__(self):
116
+ super(GradientMaxLoss, self).__init__()
117
+ self.sobel_x = nn.Parameter(torch.FloatTensor([[-1, 0, 1],
118
+ [-2, 0, 2],
119
+ [-1, 0, 1]]).view(1, 1, 3, 3), requires_grad=False).cuda()
120
+ self.sobel_y = nn.Parameter(torch.FloatTensor([[-1, -2, -1],
121
+ [0, 0, 0],
122
+ [1, 2, 1]]).view(1, 1, 3, 3), requires_grad=False).cuda()
123
+ self.padding = (1, 1, 1, 1)
124
+
125
+ def forward(self, image_A, image_B, image_fuse):
126
+ gradient_A_x, gradient_A_y = self.gradient(image_A)
127
+ gradient_B_x, gradient_B_y = self.gradient(image_B)
128
+ gradient_fuse_x, gradient_fuse_y = self.gradient(image_fuse)
129
+ loss = F.l1_loss(gradient_fuse_x, torch.max(gradient_A_x, gradient_B_x)) + F.l1_loss(gradient_fuse_y, torch.max(gradient_A_y, gradient_B_y))
130
+ return loss
131
+
132
+ def gradient(self, image):
133
+ image = F.pad(image, self.padding, mode='replicate')
134
+ gradient_x = F.conv2d(image, self.sobel_x, padding=0)
135
+ gradient_y = F.conv2d(image, self.sobel_y, padding=0)
136
+ return torch.abs(gradient_x), torch.abs(gradient_y)
137
+
138
+ class L_SSIM(torch.nn.Module):
139
+ def __init__(self, window_size=11, size_average=True, val_range=None):
140
+ super(L_SSIM, self).__init__()
141
+ self.window_size = window_size
142
+ self.size_average = size_average
143
+ self.val_range = val_range
144
+
145
+ self.channel = 1
146
+ self.window = create_window(window_size)
147
+
148
+ def forward(self, img1, img2):
149
+ (_, channel, _, _) = img1.size()
150
+ (_, channel_2, _, _) = img2.size()
151
+
152
+ if channel != channel_2 and channel == 1:
153
+ img1 = torch.concat([img1, img1, img1], dim=1)
154
+ channel = 3
155
+
156
+ if channel == self.channel and self.window.dtype == img1.dtype:
157
+ window = self.window.cuda()
158
+ else:
159
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
160
+ self.window = window.cuda()
161
+ self.channel = channel
162
+
163
+ return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
164
+
165
+
166
+ class L_color(nn.Module):
167
+ def __init__(self):
168
+ super(L_color, self).__init__()
169
+
170
+ def forward(self, image_visible, image_fused):
171
+ ycbcr_visible = self.rgb_to_ycbcr(image_visible)
172
+ ycbcr_fused = self.rgb_to_ycbcr(image_fused)
173
+
174
+ cb_visible = ycbcr_visible[:, 1, :, :]
175
+ cr_visible = ycbcr_visible[:, 2, :, :]
176
+ cb_fused = ycbcr_fused[:, 1, :, :]
177
+ cr_fused = ycbcr_fused[:, 2, :, :]
178
+
179
+ loss_cb = F.l1_loss(cb_visible, cb_fused)
180
+ loss_cr = F.l1_loss(cr_visible, cr_fused)
181
+
182
+ loss_color = loss_cb + loss_cr
183
+ return loss_color
184
+
185
+ def rgb_to_ycbcr(self, image):
186
+ r = image[:, 0, :, :]
187
+ g = image[:, 1, :, :]
188
+ b = image[:, 2, :, :]
189
+
190
+ y = 0.299 * r + 0.587 * g + 0.114 * b
191
+ cb = -0.168736 * r - 0.331264 * g + 0.5 * b
192
+ cr = 0.5 * r - 0.418688 * g - 0.081312 * b
193
+
194
+ ycbcr_image = torch.stack((y, cb, cr), dim=1)
195
+ return ycbcr_image
196
+
LUT-Fuse-main/test_lut.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ from torchvision.transforms import ToPILImage
+ import time
+ from data.simple_dataset import RandomCropPair
+ from data.simple_dataset import SimpleDataSet
+ import torch.nn as nn
+ import transforms as T
+ import os
+ from scripts.calculate import load_lookup_table, Generator_for_info, apply_fusion_4d_with_interpolation
+
+
+ def main():
+     lut_filepath = " "
+     context_file = " "
+     infrared_dir = " "
+     visible_dir = " "
+     save_dir = " "
+     os.makedirs(save_dir, exist_ok=True)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     lut = load_lookup_table(lut_filepath)
+     if lut is None:
+         return
+     lut = lut.to(device)
+
+     get_context = Generator_for_info()
+     get_context.load_state_dict(torch.load(context_file))
+     get_context = get_context.to(device)
+     get_context.eval()
+
+     data_transform = {
+         "train": RandomCropPair(size=(96, 96)),
+         "val": T.Compose([T.Resize_16(),
+                           T.ToTensor()])}
+
+     val_dataset = SimpleDataSet(visible_path=visible_dir,
+                                 infrared_path=infrared_dir,
+                                 phase="val",
+                                 transform=data_transform["val"])
+     val_loader = torch.utils.data.DataLoader(val_dataset,
+                                              batch_size=1,
+                                              shuffle=False,
+                                              pin_memory=True,
+                                              num_workers=1,
+                                              collate_fn=val_dataset.collate_fn)
+
+     infrared_files = sorted(os.listdir(infrared_dir))
+     visible_files = sorted(os.listdir(visible_dir))
+
+     assert len(infrared_files) == len(visible_files), "The number of images in the infrared and visible folders does not match!"
+     target_size = (128, 128)
+     times = []
+
+     for step, data in enumerate(val_loader):
+         I_A, I_B, task = data
+
+         if torch.cuda.is_available():
+             I_A = I_A.to("cuda")
+             I_B = I_B.to("cuda")
+
+         torch.cuda.synchronize()
+         start_time = time.time()
+         outputs = apply_fusion_4d_with_interpolation(I_A * 255., I_B * 255., lut, get_context)
+         torch.cuda.synchronize()
+         end_time = time.time()
+         elapsed_time = end_time - start_time
+         times.append(elapsed_time)
+
+         if not os.path.splitext(task[0])[1]:
+             task_with_extension = task[0] + ".png"
+         else:
+             task_with_extension = task[0]
+         save_path = os.path.join(save_dir, task_with_extension)
+         fusion_result = outputs.squeeze(0).clamp(0, 1).cpu()
+         fusion_result_image = ToPILImage()(fusion_result)
+         fusion_result_image.save(save_path)
+
+     warmup_skip = 25
+     if len(times) > warmup_skip:
+         times_after_warmup = times[warmup_skip:]
+         avg_time = np.mean(times_after_warmup)
+         std_time = np.std(times_after_warmup)
+         print(f"Processing completed! After skipping the first {warmup_skip} images, avg_time: {avg_time:.4f} seconds, std_time: {std_time:.4f} seconds")
+     else:
+         print(f"Not enough images to skip the first {warmup_skip}! Total images: {len(times)}")
+
+
+ if __name__ == "__main__":
+     main()
+
LUT-Fuse-main/transforms.py ADDED
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import random
3
+
4
+ import torch
5
+ from torchvision import transforms as T
6
+ from torchvision.transforms import functional as F
7
+
8
+
9
+ def pad_if_smaller(img, size, fill=0):
10
+ min_size = min(img.size)
11
+ if min_size < size:
12
+ ow, oh = img.size
13
+ padh = size - oh if oh < size else 0
14
+ padw = size - ow if ow < size else 0
15
+ img = F.pad(img, (0, 0, padw, padh), fill=fill)
16
+ return img
17
+
18
+
19
+ class Compose(object):
20
+ def __init__(self, transforms):
21
+ self.transforms = transforms
22
+
23
+ def __call__(self, image):
24
+ for t in self.transforms:
25
+ image = t(image)
26
+ return image
27
+
28
+
29
+ class Resize(object):
30
+ def __init__(self, size):
31
+ self.size = size
32
+
33
+ def __call__(self, image):
34
+ image = F.resize(image, self.size)
35
+ return image
36
+
37
+ class Resize_16(object):
38
+ def __init__(self):
39
+ pass
40
+
41
+ def __call__(self, image):
42
+ width, height = image.size
43
+ new_width = (width // 16) * 16
44
+ new_height = (height // 16) * 16
45
+
46
+ image = F.resize(image, (new_height, new_width))
47
+
48
+ return image
49
+
50
+
51
+ class Resize_20(object):
52
+ def __init__(self):
53
+ pass
54
+
55
+ def __call__(self, image):
56
+ width, height = image.size
57
+ new_width = (width // 20) * 20
58
+ new_height = (height // 20) * 20
59
+
60
+ image = F.resize(image, (new_height, new_width))
61
+
62
+ return image
63
+
64
+
65
+ class RandomHorizontalFlip(object):
66
+ def __init__(self, flip_prob):
67
+ self.flip_prob = flip_prob
68
+
69
+ def __call__(self, image):
70
+ if random.random() < self.flip_prob:
71
+ image = F.hflip(image)
72
+ return image
73
+
74
+
75
+ class RandomVerticalFlip(object):
76
+ def __init__(self, flip_prob):
77
+ self.flip_prob = flip_prob
78
+
79
+ def __call__(self, image):
80
+ if random.random() < self.flip_prob:
81
+ image = F.vflip(image)
82
+ return image
83
+
84
+
85
+ class RandomCrop(object):
86
+ def __init__(self, size):
87
+ self.size = size
88
+
89
+ def __call__(self, image):
90
+ image = pad_if_smaller(image, self.size)
91
+ crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
92
+ image = F.crop(image, *crop_params)
93
+ return image
94
+
95
+ class CenterCrop(object):
96
+ def __init__(self, size):
97
+ self.size = size
98
+
99
+ def __call__(self, image):
100
+ image = F.center_crop(image, self.size)
101
+
102
+ return image
103
+
104
+ class ToTensor(object):
105
+ def __call__(self, image):
106
+ image = F.to_tensor(image)
107
+ return image