Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,19 @@ from gfpgan.utils import GFPGANer
|
|
7 |
from realesrgan.utils import RealESRGANer
|
8 |
import uuid
|
9 |
import tempfile
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
if not os.path.exists('realesr-general-x4v3.pth'):
|
12 |
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
13 |
if not os.path.exists('GFPGANv1.2.pth'):
|
@@ -31,6 +43,36 @@ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, ti
|
|
31 |
|
32 |
os.makedirs('output', exist_ok=True)
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
@app.route('/api/restore', methods=['POST'])
|
35 |
def restore_image():
|
36 |
try:
|
@@ -41,7 +83,7 @@ def restore_image():
|
|
41 |
file = request.files['file']
|
42 |
version = request.form.get('version', 'v1.4')
|
43 |
scale = float(request.form.get('scale', 2))
|
44 |
-
|
45 |
|
46 |
# 一時ファイルに保存
|
47 |
temp_dir = tempfile.mkdtemp()
|
@@ -49,7 +91,7 @@ def restore_image():
|
|
49 |
file.save(input_path)
|
50 |
|
51 |
# 画像処理
|
52 |
-
extension = os.path.splitext(os.path.basename(str(input_path))
|
53 |
img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
|
54 |
|
55 |
if len(img.shape) == 3 and img.shape[2] == 4:
|
@@ -68,25 +110,26 @@ def restore_image():
|
|
68 |
if version == 'v1.2':
|
69 |
face_enhancer = GFPGANer(
|
70 |
model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
71 |
elif version == 'v1.3':
|
72 |
face_enhancer = GFPGANer(
|
73 |
model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
74 |
elif version == 'v1.4':
|
75 |
face_enhancer = GFPGANer(
|
76 |
model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
77 |
elif version == 'RestoreFormer':
|
78 |
face_enhancer = GFPGANer(
|
79 |
model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
80 |
elif version == 'CodeFormer':
|
81 |
-
|
82 |
-
model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
83 |
elif version == 'RealESR-General-x4v3':
|
84 |
face_enhancer = GFPGANer(
|
85 |
-
model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
86 |
|
87 |
-
# 画像を拡張
|
88 |
-
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
89 |
-
|
90 |
# スケール調整
|
91 |
if scale != 2:
|
92 |
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
|
@@ -179,6 +222,10 @@ def index():
|
|
179 |
<label for="scale">Rescaling factor:</label>
|
180 |
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
|
181 |
</div>
|
|
|
|
|
|
|
|
|
182 |
<button type="submit" id="submitButton">Process Image</button>
|
183 |
</form>
|
184 |
<div id="loading" class="loader"></div>
|
@@ -198,11 +245,23 @@ def index():
|
|
198 |
</div>
|
199 |
|
200 |
<script>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
// フォームの変更を監視してAPI使用例を更新
|
202 |
function updateApiUsage() {
|
203 |
const fileInput = document.getElementById('file');
|
204 |
const version = document.getElementById('version').value;
|
205 |
const scale = document.getElementById('scale').value;
|
|
|
206 |
|
207 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
208 |
const baseUrl = window.location.origin;
|
@@ -216,27 +275,32 @@ def index():
|
|
216 |
reader.onload = function(e) {
|
217 |
const dataURL = e.target.result;
|
218 |
if (dataURL.length > 40) {
|
219 |
-
filePreview =
|
220 |
} else {
|
221 |
-
filePreview =
|
222 |
}
|
223 |
-
updateFetchCode(apiUrl, version, scale, filePreview);
|
224 |
};
|
225 |
reader.readAsDataURL(file);
|
226 |
} else {
|
227 |
-
updateFetchCode(apiUrl, version, scale, filePreview);
|
228 |
}
|
229 |
}
|
230 |
|
231 |
-
function updateFetchCode(apiUrl, version, scale, filePreview) {
|
232 |
const fetchCodeDiv = document.getElementById('fetchCode');
|
233 |
-
|
234 |
-
// JavaScript fetch example:
|
235 |
const formData = new FormData();
|
236 |
formData.append('file', ${filePreview});
|
237 |
formData.append('version', '${version}');
|
238 |
-
formData.append('scale', ${scale})
|
|
|
|
|
|
|
|
|
|
|
239 |
|
|
|
240 |
fetch('${apiUrl}', {
|
241 |
method: 'POST',
|
242 |
body: formData
|
@@ -256,12 +320,15 @@ fetch('${apiUrl}', {
|
|
256 |
.catch(error => {
|
257 |
console.error('Error:', error.message);
|
258 |
});`;
|
|
|
|
|
259 |
}
|
260 |
|
261 |
// フォーム要素の変更を監視
|
262 |
document.getElementById('file').addEventListener('change', updateApiUsage);
|
263 |
document.getElementById('version').addEventListener('change', updateApiUsage);
|
264 |
document.getElementById('scale').addEventListener('input', updateApiUsage);
|
|
|
265 |
|
266 |
// 初期表示
|
267 |
updateApiUsage();
|
@@ -281,6 +348,11 @@ fetch('${apiUrl}', {
|
|
281 |
formData.append('version', document.getElementById('version').value);
|
282 |
formData.append('scale', document.getElementById('scale').value);
|
283 |
|
|
|
|
|
|
|
|
|
|
|
284 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
285 |
const baseUrl = window.location.origin;
|
286 |
const apiUrl = baseUrl + '/api/restore';
|
@@ -321,5 +393,4 @@ fetch('${apiUrl}', {
|
|
321 |
"""
|
322 |
|
323 |
if __name__ == '__main__':
|
324 |
-
|
325 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
|
|
7 |
from realesrgan.utils import RealESRGANer
|
8 |
import uuid
|
9 |
import tempfile
|
10 |
+
from torchvision.transforms.functional import normalize
|
11 |
+
from torchvision import transforms
|
12 |
+
from PIL import Image
|
13 |
+
from basicsr.utils import img2tensor, tensor2img
|
14 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
15 |
+
from codeformer.archs.codeformer_arch import CodeFormer
|
16 |
+
|
17 |
+
# 依存関係のインストール
|
18 |
+
os.system("git clone https://github.com/sczhou/CodeFormer.git")
|
19 |
+
os.system("cd CodeFormer && pip install -r requirements.txt")
|
20 |
+
os.system("cd CodeFormer && python basicsr/setup.py develop")
|
21 |
+
|
22 |
+
# ウェイトファイルをダウンロード(毎回消えるので毎回必ず実行。)
|
23 |
if not os.path.exists('realesr-general-x4v3.pth'):
|
24 |
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
25 |
if not os.path.exists('GFPGANv1.2.pth'):
|
|
|
43 |
|
44 |
os.makedirs('output', exist_ok=True)
|
45 |
|
46 |
+
def restore_with_codeformer(img, scale=2, weight=0.5):
|
47 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
48 |
+
net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
49 |
+
net.load_state_dict(torch.load('CodeFormer.pth')['params_ema'])
|
50 |
+
net.eval()
|
51 |
+
|
52 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
53 |
+
img = Image.fromarray(img)
|
54 |
+
|
55 |
+
face_helper = FaceRestoreHelper(
|
56 |
+
upscale_factor=scale, face_size=512, crop_ratio=(1, 1), use_parse=True,
|
57 |
+
device=device)
|
58 |
+
|
59 |
+
face_helper.clean_all()
|
60 |
+
face_helper.read_image(img)
|
61 |
+
face_helper.get_face_landmarks_5(only_center_face=False, resize=640)
|
62 |
+
face_helper.align_warp_face()
|
63 |
+
|
64 |
+
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
65 |
+
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=False, float32=True)
|
66 |
+
normalize(cropped_face_t, [0.5], [0.5], inplace=True)
|
67 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
68 |
+
with torch.no_grad():
|
69 |
+
output = net(cropped_face_t, w=weight, adain=True)[0]
|
70 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
71 |
+
face_helper.add_restored_face(restored_face)
|
72 |
+
|
73 |
+
restored_img = face_helper.paste_faces_to_input_image()
|
74 |
+
return cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
|
75 |
+
|
76 |
@app.route('/api/restore', methods=['POST'])
|
77 |
def restore_image():
|
78 |
try:
|
|
|
83 |
file = request.files['file']
|
84 |
version = request.form.get('version', 'v1.4')
|
85 |
scale = float(request.form.get('scale', 2))
|
86 |
+
weight = float(request.form.get('weight', 0.5)) # CodeFormer用のweightパラメータ
|
87 |
|
88 |
# 一時ファイルに保存
|
89 |
temp_dir = tempfile.mkdtemp()
|
|
|
91 |
file.save(input_path)
|
92 |
|
93 |
# 画像処理
|
94 |
+
extension = os.path.splitext(os.path.basename(str(input_path))[1]
|
95 |
img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
|
96 |
|
97 |
if len(img.shape) == 3 and img.shape[2] == 4:
|
|
|
110 |
if version == 'v1.2':
|
111 |
face_enhancer = GFPGANer(
|
112 |
model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
113 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
114 |
elif version == 'v1.3':
|
115 |
face_enhancer = GFPGANer(
|
116 |
model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
117 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
118 |
elif version == 'v1.4':
|
119 |
face_enhancer = GFPGANer(
|
120 |
model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
121 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
122 |
elif version == 'RestoreFormer':
|
123 |
face_enhancer = GFPGANer(
|
124 |
model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
125 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
126 |
elif version == 'CodeFormer':
|
127 |
+
output = restore_with_codeformer(img, scale=scale, weight=weight)
|
|
|
128 |
elif version == 'RealESR-General-x4v3':
|
129 |
face_enhancer = GFPGANer(
|
130 |
+
model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler, map_location=torch.device('cpu'))
|
131 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
132 |
|
|
|
|
|
|
|
133 |
# スケール調整
|
134 |
if scale != 2:
|
135 |
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
|
|
|
222 |
<label for="scale">Rescaling factor:</label>
|
223 |
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
|
224 |
</div>
|
225 |
+
<div class="form-group" id="weightGroup" style="display: none;">
|
226 |
+
<label for="weight">CodeFormer Weight (0-1):</label>
|
227 |
+
<input type="number" id="weight" name="weight" value="0.5" step="0.1" min="0" max="1">
|
228 |
+
</div>
|
229 |
<button type="submit" id="submitButton">Process Image</button>
|
230 |
</form>
|
231 |
<div id="loading" class="loader"></div>
|
|
|
245 |
</div>
|
246 |
|
247 |
<script>
|
248 |
+
// CodeFormerが選択された時にweightパラメータを表示
|
249 |
+
document.getElementById('version').addEventListener('change', function() {
|
250 |
+
const weightGroup = document.getElementById('weightGroup');
|
251 |
+
if (this.value === 'CodeFormer') {
|
252 |
+
weightGroup.style.display = 'block';
|
253 |
+
} else {
|
254 |
+
weightGroup.style.display = 'none';
|
255 |
+
}
|
256 |
+
updateApiUsage();
|
257 |
+
});
|
258 |
+
|
259 |
// フォームの変更を監視してAPI使用例を更新
|
260 |
function updateApiUsage() {
|
261 |
const fileInput = document.getElementById('file');
|
262 |
const version = document.getElementById('version').value;
|
263 |
const scale = document.getElementById('scale').value;
|
264 |
+
const weight = document.getElementById('weight').value;
|
265 |
|
266 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
267 |
const baseUrl = window.location.origin;
|
|
|
275 |
reader.onload = function(e) {
|
276 |
const dataURL = e.target.result;
|
277 |
if (dataURL.length > 40) {
|
278 |
+
filePreview = "${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}";
|
279 |
} else {
|
280 |
+
filePreview = "${dataURL}";
|
281 |
}
|
282 |
+
updateFetchCode(apiUrl, version, scale, weight, filePreview);
|
283 |
};
|
284 |
reader.readAsDataURL(file);
|
285 |
} else {
|
286 |
+
updateFetchCode(apiUrl, version, scale, weight, filePreview);
|
287 |
}
|
288 |
}
|
289 |
|
290 |
+
function updateFetchCode(apiUrl, version, scale, weight, filePreview) {
|
291 |
const fetchCodeDiv = document.getElementById('fetchCode');
|
292 |
+
let code = `// JavaScript fetch example:
|
|
|
293 |
const formData = new FormData();
|
294 |
formData.append('file', ${filePreview});
|
295 |
formData.append('version', '${version}');
|
296 |
+
formData.append('scale', ${scale});`;
|
297 |
+
|
298 |
+
if (version === 'CodeFormer') {
|
299 |
+
code += `
|
300 |
+
formData.append('weight', ${weight});`;
|
301 |
+
}
|
302 |
|
303 |
+
code += `
|
304 |
fetch('${apiUrl}', {
|
305 |
method: 'POST',
|
306 |
body: formData
|
|
|
320 |
.catch(error => {
|
321 |
console.error('Error:', error.message);
|
322 |
});`;
|
323 |
+
|
324 |
+
fetchCodeDiv.innerHTML = code;
|
325 |
}
|
326 |
|
327 |
// フォーム要素の変更を監視
|
328 |
document.getElementById('file').addEventListener('change', updateApiUsage);
|
329 |
document.getElementById('version').addEventListener('change', updateApiUsage);
|
330 |
document.getElementById('scale').addEventListener('input', updateApiUsage);
|
331 |
+
document.getElementById('weight').addEventListener('input', updateApiUsage);
|
332 |
|
333 |
// 初期表示
|
334 |
updateApiUsage();
|
|
|
348 |
formData.append('version', document.getElementById('version').value);
|
349 |
formData.append('scale', document.getElementById('scale').value);
|
350 |
|
351 |
+
// CodeFormerが選択されている場合はweightも追加
|
352 |
+
if (document.getElementById('version').value === 'CodeFormer') {
|
353 |
+
formData.append('weight', document.getElementById('weight').value);
|
354 |
+
}
|
355 |
+
|
356 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
357 |
const baseUrl = window.location.origin;
|
358 |
const apiUrl = baseUrl + '/api/restore';
|
|
|
393 |
"""
|
394 |
|
395 |
if __name__ == '__main__':
|
|
|
396 |
app.run(host='0.0.0.0', port=7860, debug=True)
|