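# Page header captured with this file (not part of the program):
# uploader avatar caption "soiz1's picture", commit message "Update app.py",
# revision 886a7d7 (verified).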
import os
import cv2
import torch
from flask import Flask, request, jsonify, send_file, render_template_string
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
import tempfile
import uuid
app = Flask(__name__)


def download_weights():
    """Download every model weight file that is not already present.

    Files are stored in the current working directory under the names used
    as keys below. A file that already exists is left untouched.

    Raises:
        Exception: whatever ``torch.hub.download_url_to_file`` raises when a
            download fails. Failing loudly here is intentional: the previous
            ``wget`` call ignored errors and could leave an empty file that
            later runs mistook for valid weights.
    """
    weights = {
        'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
        'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
        'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth',
    }
    for weight_file, url in weights.items():
        if not os.path.exists(weight_file):
            # torch is already a dependency of this file, so no external
            # ``wget`` binary or shell invocation is needed.
            torch.hub.download_url_to_file(url, weight_file)


# Ensure output directory exists
os.makedirs('output', exist_ok=True)

# Download weights if not exists. This has to run BEFORE the upsampler is
# built, because the upsampler is constructed from one of these files.
download_weights()

# Initialize models
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'realesr-general-x4v3.pth'
# Half precision is only requested when a CUDA device is available.
half = torch.cuda.is_available()
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
# Maps every accepted ``version`` value to the (weights file, arch name) pair
# that is passed to GFPGANer.
_FACE_ENHANCER_SPECS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
    'RealESR-General-x4v3': ('realesr-general-x4v3.pth', 'realesr-general'),
}


def process_image(img_path, version, scale, weight=0.5):
    """Restore/upscale the image at ``img_path`` and save the result as a JPEG.

    Args:
        img_path: Path of the input image on disk.
        version: One of the keys of ``_FACE_ENHANCER_SPECS``.
        scale: Rescaling factor. When it differs from 2 the enhanced image is
            resized to ``(w * scale / 2, h * scale / 2)``, where ``w``/``h``
            are the dimensions of the image handed to the enhancer.
        weight: Only forwarded to the enhancer when ``version`` is
            ``'CodeFormer'``.

    Returns:
        Path of the written file, ``output/output_<hex>.jpg``.

    Raises:
        Exception: Any failure is re-raised as ``Exception`` with a
            ``"Processing error: ..."`` message (the original error is kept
            as ``__cause__``).
    """
    try:
        # Validate the cheap argument first so a typo does not cost an image
        # decode and does not end in an UnboundLocalError further down.
        if version not in _FACE_ENHANCER_SPECS:
            raise ValueError(f"Unsupported version: {version!r}")
        weights_path, arch = _FACE_ENHANCER_SPECS[version]

        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError("Could not read the input file as an image")
        if len(img.shape) == 2:
            # Grayscale input: expand to 3 channels.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        h, w = img.shape[0:2]
        if h < 300:
            # Small inputs are doubled before enhancement.
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        face_enhancer = GFPGANer(
            model_path=weights_path, upscale=2, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)

        try:
            if version == 'CodeFormer':
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
            else:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True)
        except RuntimeError as error:
            print('Error', error)
            raise Exception(f"Enhancement error: {str(error)}") from error

        try:
            if scale != 2:
                interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
                # NOTE(review): the target size is derived from the enhancer's
                # INPUT size, not from `output` -- confirm this is the intended
                # meaning of `scale`; behavior is deliberately left unchanged.
                h, w = img.shape[0:2]
                output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
        except Exception as error:
            # Best effort: a failed resize keeps the un-resized enhancer output.
            print('wrong scale input.', error)

        # Save to the output directory under a unique name
        output_filename = f"output_{uuid.uuid4().hex}.jpg"
        output_path = os.path.join('output', output_filename)
        if output.ndim == 3 and output.shape[2] == 4:
            # The file is always written as JPEG, which cannot store alpha,
            # so drop the alpha channel explicitly.
            output = cv2.cvtColor(output, cv2.COLOR_BGRA2BGR)
        if not cv2.imwrite(output_path, output, [int(cv2.IMWRITE_JPEG_QUALITY), 95]):
            raise IOError(f"Could not write output image to {output_path}")
        return output_path
    except Exception as error:
        print('Global exception', error)
        raise Exception(f"Processing error: {str(error)}") from error
# HTML/CSS/JS for the landing page: an upload form that POSTs to /api/restore
# and a live preview of the equivalent JavaScript fetch() call.
_INDEX_PAGE_TEMPLATE = '''
<!DOCTYPE html>
<html>
<head>
<title>Image Upscaling & Restoration API</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
.form-group { margin-bottom: 15px; }
label { display: block; margin-bottom: 5px; }
input, select { width: 100%; padding: 8px; box-sizing: border-box; }
button { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; }
button:hover { background-color: #45a049; }
#result { margin-top: 20px; }
#preview { max-width: 100%; margin-top: 10px; }
#apiUsage { background-color: #f5f5f5; padding: 15px; border-radius: 5px; margin-top: 20px; font-family: monospace; white-space: pre-wrap; }
#apiUsage h3 { margin-top: 0; }
#formDataPreview { max-height: 200px; overflow-y: auto; margin-bottom: 10px; }
.code-block { background-color: #f8f8f8; padding: 10px; border-radius: 4px; border-left: 3px solid #4CAF50; }
.comment { color: #666; font-style: italic; }
.loader {
width: 48px;
height: 48px;
border: 5px solid #4CAF50;
border-bottom-color: transparent;
border-radius: 50%;
display: inline-block;
box-sizing: border-box;
animation: rotation 1s linear infinite;
margin: 20px auto;
display: none; /* 初期状態では非表示 */
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
</style>
</head>
<body>
<h1>Image Upscaling & Restoration API</h1>
<div class="container">
<form id="uploadForm" enctype="multipart/form-data">
<div class="form-group">
<label for="file">Upload Image:</label>
<input type="file" id="file" name="file" required>
</div>
<div class="form-group">
<label for="version">Version:</label>
<select id="version" name="version">
<option value="v1.2">GFPGANv1.2</option>
<option value="v1.3">GFPGANv1.3</option>
<option value="v1.4" selected>GFPGANv1.4</option>
<option value="RestoreFormer">RestoreFormer</option>
<option value="CodeFormer">CodeFormer</option>
<option value="RealESR-General-x4v3">RealESR-General-x4v3</option>
</select>
</div>
<div class="form-group">
<label for="scale">Rescaling factor:</label>
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
</div>
<div class="form-group" id="weightGroup" style="display: none;">
<label for="weight">CodeFormer Weight (0-1):</label>
<input type="number" id="weight" name="weight" value="0.5" step="0.1" min="0" max="1">
</div>
<button type="submit" id="submitButton">Process Image</button>
</form>
<div id="loading" class="loader"></div>
<div id="result">
<h3>Result:</h3>
<div id="outputContainer" style="display: none;">
<img id="preview" src="" alt="Processed Image">
<a id="downloadLink" href="#" download>Download Image</a>
</div>
</div>
<div id="apiUsage">
<h3>API Usage:</h3>
<div id="fetchCode" class="code-block">
// JavaScript fetch code will appear here
</div>
</div>
</div>
<script>
// CodeFormerが選択された時にweightパラメータを表示
document.getElementById('version').addEventListener('change', function() {
const weightGroup = document.getElementById('weightGroup');
if (this.value === 'CodeFormer') {
weightGroup.style.display = 'block';
} else {
weightGroup.style.display = 'none';
}
updateApiUsage();
});
// フォームの変更を監視してAPI使用例を更新
function updateApiUsage() {
const fileInput = document.getElementById('file');
const version = document.getElementById('version').value;
const scale = document.getElementById('scale').value;
const weight = document.getElementById('weight').value;
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
const baseUrl = window.location.origin;
const apiUrl = baseUrl + '/api/restore';
// ファイルのプレビュー用文字列を準備
let filePreview = '"img-dataURL"';
if (fileInput.files.length > 0) {
const file = fileInput.files[0];
const reader = new FileReader();
reader.onload = function(e) {
const dataURL = e.target.result;
if (dataURL.length > 40) {
filePreview = `"${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}"`;
} else {
filePreview = `"${dataURL}"`;
}
updateFetchCode(apiUrl, version, scale, weight, filePreview);
};
reader.readAsDataURL(file);
} else {
updateFetchCode(apiUrl, version, scale, weight, filePreview);
}
}
function updateFetchCode(apiUrl, version, scale, weight, filePreview) {
const fetchCodeDiv = document.getElementById('fetchCode');
let code = `// JavaScript fetch example:
const formData = new FormData();
formData.append('file', ${filePreview});
formData.append('version', '${version}');
formData.append('scale', ${scale});`;
if (version === 'CodeFormer') {
code += `
formData.append('weight', ${weight});`;
}
code += `
fetch('${apiUrl}', {
method: 'POST',
body: formData
})
.then(response => {
if (!response.ok) {
return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
}
return response.blob();
})
.then(blob => {
// Process the returned image blob
const url = URL.createObjectURL(blob);
console.log('Image processed successfully', url);
// Example: document.getElementById('resultImage').src = url;
})
.catch(error => {
console.error('Error:', error.message);
});`;
fetchCodeDiv.innerHTML = code;
}
// フォーム要素の変更を監視
document.getElementById('file').addEventListener('change', updateApiUsage);
document.getElementById('version').addEventListener('change', updateApiUsage);
document.getElementById('scale').addEventListener('input', updateApiUsage);
document.getElementById('weight').addEventListener('input', updateApiUsage);
// 初期表示
updateApiUsage();
document.getElementById('uploadForm').addEventListener('submit', function(e) {
e.preventDefault();
// ボタンを無効化し、ローディングを表示
const submitButton = document.getElementById('submitButton');
const loadingElement = document.getElementById('loading');
submitButton.disabled = true;
loadingElement.style.display = 'block';
const formData = new FormData();
formData.append('file', document.getElementById('file').files[0]);
formData.append('version', document.getElementById('version').value);
formData.append('scale', document.getElementById('scale').value);
// CodeFormerが選択されている場合はweightも追加
if (document.getElementById('version').value === 'CodeFormer') {
formData.append('weight', document.getElementById('weight').value);
}
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
const baseUrl = window.location.origin;
const apiUrl = baseUrl + '/api/restore';
fetch(apiUrl, {
method: 'POST',
body: formData
})
.then(response => {
if (!response.ok) {
return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
}
return response.blob();
})
.then(blob => {
const url = URL.createObjectURL(blob);
const preview = document.getElementById('preview');
const downloadLink = document.getElementById('downloadLink');
const outputContainer = document.getElementById('outputContainer');
preview.src = url;
downloadLink.href = url;
downloadLink.download = 'restored_' + document.getElementById('file').files[0].name;
outputContainer.style.display = 'block';
})
.catch(error => {
alert('Error: ' + error.message);
})
.finally(() => {
// 処理が終わったらローディングを非表示にし、ボタンを再有効化
loadingElement.style.display = 'none';
submitButton.disabled = false;
});
});
</script>
</body>
</html>
'''


@app.route('/')
def index():
    """Render the landing page with the upload form and API usage example."""
    page = render_template_string(_INDEX_PAGE_TEMPLATE)
    return page
@app.route('/api/restore', methods=['POST'])
def api_restore():
    """Handle ``POST /api/restore``: restore an uploaded image and return it.

    Form fields:
        file: The image to process (required).
        version: Model selector, defaults to ``'v1.4'``.
        scale: Rescaling factor, defaults to ``2``; must be positive and finite.
        weight: Only read when ``version`` is ``'CodeFormer'``; defaults to ``0.5``.

    Returns:
        The processed image with mimetype ``image/jpeg`` on success, or a JSON
        body ``{"error": ...}`` with status 400 (bad request data) or 500
        (processing failure).
    """
    # Local import keeps this handler self-contained; the result is buffered
    # in memory so the file on disk can be removed before returning.
    import io

    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    version = request.form.get('version', 'v1.4')
    try:
        scale = float(request.form.get('scale', 2))
        weight = float(request.form.get('weight', 0.5)) if version == 'CodeFormer' else None
    except ValueError:
        # Malformed numbers are a client error, not a server crash.
        return jsonify({'error': 'scale and weight must be numbers'}), 400
    if not 0 < scale < float('inf'):
        # Also rejects NaN, because every comparison with NaN is False.
        return jsonify({'error': 'scale must be a positive, finite number'}), 400

    # The client-supplied filename is untrusted: joining it into a path would
    # let '../' segments or an absolute path escape the temp directory. Keep
    # only a sanitized extension and generate the rest of the name ourselves.
    raw_extension = os.path.splitext(os.path.basename(file.filename))[1]
    extension = ''.join(ch for ch in raw_extension if ch.isalnum() or ch == '.')
    input_name = f"input_{uuid.uuid4().hex}{extension}"

    output_path = None
    try:
        # Save uploaded file to temp location; the directory and its contents
        # are removed automatically when the block exits.
        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, input_name)
            file.save(input_path)
            # Process image
            output_path = process_image(input_path, version, scale, weight)

        # Return the processed image
        with open(output_path, 'rb') as result_file:
            payload = io.BytesIO(result_file.read())
        return send_file(payload, mimetype='image/jpeg')
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # Every request writes a new file under output/; remove it so the
        # directory does not grow without bound.
        if output_path is not None and os.path.exists(output_path):
            os.remove(output_path)
if __name__ == '__main__':
    # The server listens on all interfaces, so the interactive debugger must
    # not be enabled by default: it lets anyone who can reach the port run
    # code on the host. Opt in explicitly with FLASK_DEBUG=1 for local work.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)