soiz1 commited on
Commit
66d414f
·
verified ·
1 Parent(s): f56281b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -18
app.py CHANGED
@@ -7,7 +7,19 @@ from gfpgan.utils import GFPGANer
7
  from realesrgan.utils import RealESRGANer
8
  import uuid
9
  import tempfile
10
- # ウェイトファイルをダウンロード(存在しない場合)
 
 
 
 
 
 
 
 
 
 
 
 
11
  if not os.path.exists('realesr-general-x4v3.pth'):
12
  os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
13
  if not os.path.exists('GFPGANv1.2.pth'):
@@ -31,6 +43,36 @@ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, ti
31
 
32
  os.makedirs('output', exist_ok=True)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @app.route('/api/restore', methods=['POST'])
35
  def restore_image():
36
  try:
@@ -41,7 +83,7 @@ def restore_image():
41
  file = request.files['file']
42
  version = request.form.get('version', 'v1.4')
43
  scale = float(request.form.get('scale', 2))
44
- # weight = float(request.form.get('weight', 50)) / 100 # CodeFormer用のweightパラメータが必要な場合
45
 
46
  # 一時ファイルに保存
47
  temp_dir = tempfile.mkdtemp()
@@ -49,7 +91,7 @@ def restore_image():
49
  file.save(input_path)
50
 
51
  # 画像処理
52
- extension = os.path.splitext(os.path.basename(str(input_path)))[1]
53
  img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
54
 
55
  if len(img.shape) == 3 and img.shape[2] == 4:
@@ -68,25 +110,26 @@ def restore_image():
68
  if version == 'v1.2':
69
  face_enhancer = GFPGANer(
70
  model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
71
  elif version == 'v1.3':
72
  face_enhancer = GFPGANer(
73
  model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
74
  elif version == 'v1.4':
75
  face_enhancer = GFPGANer(
76
  model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
77
  elif version == 'RestoreFormer':
78
  face_enhancer = GFPGANer(
79
  model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
 
80
  elif version == 'CodeFormer':
81
- face_enhancer = GFPGANer(
82
- model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
83
  elif version == 'RealESR-General-x4v3':
84
  face_enhancer = GFPGANer(
85
- model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler)
 
86
 
87
- # 画像を拡張
88
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
89
-
90
  # スケール調整
91
  if scale != 2:
92
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
@@ -179,6 +222,10 @@ def index():
179
  <label for="scale">Rescaling factor:</label>
180
  <input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
181
  </div>
 
 
 
 
182
  <button type="submit" id="submitButton">Process Image</button>
183
  </form>
184
  <div id="loading" class="loader"></div>
@@ -198,11 +245,23 @@ def index():
198
  </div>
199
 
200
  <script>
 
 
 
 
 
 
 
 
 
 
 
201
  // フォームの変更を監視してAPI使用例を更新
202
  function updateApiUsage() {
203
  const fileInput = document.getElementById('file');
204
  const version = document.getElementById('version').value;
205
  const scale = document.getElementById('scale').value;
 
206
 
207
  // 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
208
  const baseUrl = window.location.origin;
@@ -216,27 +275,32 @@ def index():
216
  reader.onload = function(e) {
217
  const dataURL = e.target.result;
218
  if (dataURL.length > 40) {
219
- filePreview = `"${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}"`;
220
  } else {
221
- filePreview = `"${dataURL}"`;
222
  }
223
- updateFetchCode(apiUrl, version, scale, filePreview);
224
  };
225
  reader.readAsDataURL(file);
226
  } else {
227
- updateFetchCode(apiUrl, version, scale, filePreview);
228
  }
229
  }
230
 
231
- function updateFetchCode(apiUrl, version, scale, filePreview) {
232
  const fetchCodeDiv = document.getElementById('fetchCode');
233
- fetchCodeDiv.innerHTML = `
234
- // JavaScript fetch example:
235
  const formData = new FormData();
236
  formData.append('file', ${filePreview});
237
  formData.append('version', '${version}');
238
- formData.append('scale', ${scale});
 
 
 
 
 
239
 
 
240
  fetch('${apiUrl}', {
241
  method: 'POST',
242
  body: formData
@@ -256,12 +320,15 @@ fetch('${apiUrl}', {
256
  .catch(error => {
257
  console.error('Error:', error.message);
258
  });`;
 
 
259
  }
260
 
261
  // フォーム要素の変更を監視
262
  document.getElementById('file').addEventListener('change', updateApiUsage);
263
  document.getElementById('version').addEventListener('change', updateApiUsage);
264
  document.getElementById('scale').addEventListener('input', updateApiUsage);
 
265
 
266
  // 初期表示
267
  updateApiUsage();
@@ -281,6 +348,11 @@ fetch('${apiUrl}', {
281
  formData.append('version', document.getElementById('version').value);
282
  formData.append('scale', document.getElementById('scale').value);
283
 
 
 
 
 
 
284
  // 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
285
  const baseUrl = window.location.origin;
286
  const apiUrl = baseUrl + '/api/restore';
@@ -321,5 +393,4 @@ fetch('${apiUrl}', {
321
  """
322
 
323
  if __name__ == '__main__':
324
-
325
  app.run(host='0.0.0.0', port=7860, debug=True)
 
7
  from realesrgan.utils import RealESRGANer
8
  import uuid
9
  import tempfile
10
+ from torchvision.transforms.functional import normalize
11
+ from torchvision import transforms
12
+ from PIL import Image
13
+ from basicsr.utils import img2tensor, tensor2img
14
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
15
+ from codeformer.archs.codeformer_arch import CodeFormer
16
+
17
+ # 依存関係のインストール
18
+ os.system("git clone https://github.com/sczhou/CodeFormer.git")
19
+ os.system("cd CodeFormer && pip install -r requirements.txt")
20
+ os.system("cd CodeFormer && python basicsr/setup.py develop")
21
+
22
+ # ウェイトファイルをダウンロード(毎回消えるので毎回必ず実行。)
23
  if not os.path.exists('realesr-general-x4v3.pth'):
24
  os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
25
  if not os.path.exists('GFPGANv1.2.pth'):
 
43
 
44
  os.makedirs('output', exist_ok=True)
45
 
46
+ def restore_with_codeformer(img, scale=2, weight=0.5):
47
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48
+ net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
49
+ net.load_state_dict(torch.load('CodeFormer.pth')['params_ema'])
50
+ net.eval()
51
+
52
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
53
+ img = Image.fromarray(img)
54
+
55
+ face_helper = FaceRestoreHelper(
56
+ upscale_factor=scale, face_size=512, crop_ratio=(1, 1), use_parse=True,
57
+ device=device)
58
+
59
+ face_helper.clean_all()
60
+ face_helper.read_image(img)
61
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640)
62
+ face_helper.align_warp_face()
63
+
64
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
65
+ cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=False, float32=True)
66
+ normalize(cropped_face_t, [0.5], [0.5], inplace=True)
67
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
68
+ with torch.no_grad():
69
+ output = net(cropped_face_t, w=weight, adain=True)[0]
70
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
71
+ face_helper.add_restored_face(restored_face)
72
+
73
+ restored_img = face_helper.paste_faces_to_input_image()
74
+ return cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
75
+
76
  @app.route('/api/restore', methods=['POST'])
77
  def restore_image():
78
  try:
 
83
  file = request.files['file']
84
  version = request.form.get('version', 'v1.4')
85
  scale = float(request.form.get('scale', 2))
86
+ weight = float(request.form.get('weight', 0.5)) # CodeFormer用のweightパラメータ
87
 
88
  # 一時ファイルに保存
89
  temp_dir = tempfile.mkdtemp()
 
91
  file.save(input_path)
92
 
93
  # 画像処理
94
+ extension = os.path.splitext(os.path.basename(str(input_path))[1]
95
  img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
96
 
97
  if len(img.shape) == 3 and img.shape[2] == 4:
 
110
  if version == 'v1.2':
111
  face_enhancer = GFPGANer(
112
  model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
113
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
114
  elif version == 'v1.3':
115
  face_enhancer = GFPGANer(
116
  model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
117
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
118
  elif version == 'v1.4':
119
  face_enhancer = GFPGANer(
120
  model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
121
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
122
  elif version == 'RestoreFormer':
123
  face_enhancer = GFPGANer(
124
  model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
125
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
126
  elif version == 'CodeFormer':
127
+ output = restore_with_codeformer(img, scale=scale, weight=weight)
 
128
  elif version == 'RealESR-General-x4v3':
129
  face_enhancer = GFPGANer(
130
+ model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler, map_location=torch.device('cpu'))
131
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
132
 
 
 
 
133
  # スケール調整
134
  if scale != 2:
135
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
 
222
  <label for="scale">Rescaling factor:</label>
223
  <input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
224
  </div>
225
+ <div class="form-group" id="weightGroup" style="display: none;">
226
+ <label for="weight">CodeFormer Weight (0-1):</label>
227
+ <input type="number" id="weight" name="weight" value="0.5" step="0.1" min="0" max="1">
228
+ </div>
229
  <button type="submit" id="submitButton">Process Image</button>
230
  </form>
231
  <div id="loading" class="loader"></div>
 
245
  </div>
246
 
247
  <script>
248
+ // CodeFormerが選択された時にweightパラメータを表示
249
+ document.getElementById('version').addEventListener('change', function() {
250
+ const weightGroup = document.getElementById('weightGroup');
251
+ if (this.value === 'CodeFormer') {
252
+ weightGroup.style.display = 'block';
253
+ } else {
254
+ weightGroup.style.display = 'none';
255
+ }
256
+ updateApiUsage();
257
+ });
258
+
259
  // フォームの変更を監視してAPI使用例を更新
260
  function updateApiUsage() {
261
  const fileInput = document.getElementById('file');
262
  const version = document.getElementById('version').value;
263
  const scale = document.getElementById('scale').value;
264
+ const weight = document.getElementById('weight').value;
265
 
266
  // 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
267
  const baseUrl = window.location.origin;
 
275
  reader.onload = function(e) {
276
  const dataURL = e.target.result;
277
  if (dataURL.length > 40) {
278
+ filePreview = "${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}";
279
  } else {
280
+ filePreview = "${dataURL}";
281
  }
282
+ updateFetchCode(apiUrl, version, scale, weight, filePreview);
283
  };
284
  reader.readAsDataURL(file);
285
  } else {
286
+ updateFetchCode(apiUrl, version, scale, weight, filePreview);
287
  }
288
  }
289
 
290
+ function updateFetchCode(apiUrl, version, scale, weight, filePreview) {
291
  const fetchCodeDiv = document.getElementById('fetchCode');
292
+ let code = `// JavaScript fetch example:
 
293
  const formData = new FormData();
294
  formData.append('file', ${filePreview});
295
  formData.append('version', '${version}');
296
+ formData.append('scale', ${scale});`;
297
+
298
+ if (version === 'CodeFormer') {
299
+ code += `
300
+ formData.append('weight', ${weight});`;
301
+ }
302
 
303
+ code += `
304
  fetch('${apiUrl}', {
305
  method: 'POST',
306
  body: formData
 
320
  .catch(error => {
321
  console.error('Error:', error.message);
322
  });`;
323
+
324
+ fetchCodeDiv.innerHTML = code;
325
  }
326
 
327
  // フォーム要素の変更を監視
328
  document.getElementById('file').addEventListener('change', updateApiUsage);
329
  document.getElementById('version').addEventListener('change', updateApiUsage);
330
  document.getElementById('scale').addEventListener('input', updateApiUsage);
331
+ document.getElementById('weight').addEventListener('input', updateApiUsage);
332
 
333
  // 初期表示
334
  updateApiUsage();
 
348
  formData.append('version', document.getElementById('version').value);
349
  formData.append('scale', document.getElementById('scale').value);
350
 
351
+ // CodeFormerが選択されている場合はweightも追加
352
+ if (document.getElementById('version').value === 'CodeFormer') {
353
+ formData.append('weight', document.getElementById('weight').value);
354
+ }
355
+
356
  // 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
357
  const baseUrl = window.location.origin;
358
  const apiUrl = baseUrl + '/api/restore';
 
393
  """
394
 
395
  if __name__ == '__main__':
 
396
  app.run(host='0.0.0.0', port=7860, debug=True)