import gradio as gr
import cv2
import numpy as np
from deepface import DeepFace
import matplotlib.pyplot as plt
from PIL import Image
import tempfile
import os
import shutil
import pandas as pd
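
# Mount Google Drive when running inside Google Colab (skipped elsewhere).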
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    pass


def verify_faces(img1, img2, threshold=0.6, model="VGG-Face"):
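    """Verify whether two images show the same person and plot them side by side with the verdict."""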
    temp_dir = tempfile.mkdtemp()
    try:
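        # DeepFace works on file paths, so stage both inputs in the temp directory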
        img1_path = os.path.join(temp_dir, "img1.jpg")
        img2_path = os.path.join(temp_dir, "img2.jpg")
        img1_pil = Image.fromarray(img1) if isinstance(img1, np.ndarray) else img1
        img2_pil = Image.fromarray(img2) if isinstance(img2, np.ndarray) else img2
        img1_pil.convert("RGB").save(img1_path)
        img2_pil.convert("RGB").save(img2_path)

        result = DeepFace.verify(
            img1_path=img1_path,
            img2_path=img2_path,
            model_name=model,
            distance_metric="cosine"
        )

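        # Show the two images side by side with the verdict and distance in the title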
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        for i, path in enumerate([img1_path, img2_path]):
            img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            ax[i].imshow(img)
            ax[i].axis('off')
            ax[i].set_title(f"Image {i+1}")

        verified = result['distance'] <= threshold
        plt.suptitle(f"{'✅ MATCH' if verified else '❌ NO MATCH'}\nDistance: {result['distance']:.4f}")
        return fig, result

    except Exception as e:
        return None, {"error": str(e)}
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


def find_faces(query_img, db_input, threshold=0.6, model="VGG-Face"):
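    """Search a face database for matches to the query image and plot the closest hits."""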
    temp_dir = tempfile.mkdtemp()
    try:
        query_path = os.path.join(temp_dir, "query.jpg")
        query_pil = Image.fromarray(query_img) if isinstance(query_img, np.ndarray) else query_img
        query_pil.convert("RGB").save(query_path)

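        # The database may be an existing folder path (string) or a list of uploaded files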
        if isinstance(db_input, str):
            db_path = db_input
        else:
            db_path = os.path.join(temp_dir, "db")
            os.makedirs(db_path, exist_ok=True)
            for i, file in enumerate(db_input):
                # gr.File yields path strings or file objects depending on the Gradio version
                src = file if isinstance(file, str) else file.name
                ext = os.path.splitext(src)[1]
                shutil.copy(src, os.path.join(db_path, f"img_{i}{ext}"))

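        # Search the database; DeepFace.find fails if no usable face is found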
        try:
            dfs = DeepFace.find(
                img_path=query_path,
                db_path=db_path,
                model_name=model,
                distance_metric="cosine",
                silent=True
            )
        except Exception:
            return None, {"error": "No faces found in database"}

        df = dfs[0] if isinstance(dfs, list) else dfs
        df = df[df['distance'] <= threshold].sort_values('distance')
        if df.empty:
            return None, {"error": "No matches found below the similarity threshold"}

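        # Plot the query image followed by up to four closest matches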
        num_matches = min(4, len(df))
        fig, axes = plt.subplots(1, num_matches + 1, figsize=(15, 5))

        query_display = cv2.cvtColor(cv2.imread(query_path), cv2.COLOR_BGR2RGB)
        axes[0].imshow(query_display)
        axes[0].set_title("Query")
        axes[0].axis('off')

        for i in range(num_matches):
            match_path = df.iloc[i]['identity']
            match_img = cv2.cvtColor(cv2.imread(match_path), cv2.COLOR_BGR2RGB)
            axes[i + 1].imshow(match_img)
            axes[i + 1].set_title(f"Match {i+1}\n{df.iloc[i]['distance']:.4f}")
            axes[i + 1].axis('off')

        return fig, df[['identity', 'distance']].to_dict('records')

    except Exception as e:
        return None, {"error": str(e)}
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


def analyze_face(img, actions=['age', 'gender', 'emotion']):
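    """Estimate facial attributes (age, gender, emotion, race) and chart the numeric scores."""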
    temp_dir = tempfile.mkdtemp()
    try:
        img_path = os.path.join(temp_dir, "analyze.jpg")
        img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
        img_pil.convert("RGB").save(img_path)

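        # Run the requested attribute models; enforce_detection=False keeps going even if no face is detected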
        results = DeepFace.analyze(
            img_path=img_path,
            actions=actions,
            enforce_detection=False,
            detector_backend='opencv'
        )

        results = results if isinstance(results, list) else [results]
        fig = plt.figure(figsize=(10, 5))

        plt.subplot(121)
        img_display = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        plt.imshow(img_display)
        plt.title("Input Image")
        plt.axis('off')

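        # Keep only numeric values for the bar chart: scalar attributes (e.g. age) plus per-class scores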
        plt.subplot(122)
        numeric = (int, float, np.integer, np.floating)
        attributes = {}
        for res in results:
            for k, v in res.items():
                if k == 'region':
                    continue
                if isinstance(v, numeric):
                    attributes[k] = v
                elif isinstance(v, dict):
                    attributes.update({f"{k}: {sk}": sv for sk, sv in v.items()
                                       if isinstance(sv, numeric)})
        plt.barh(list(attributes.keys()), list(attributes.values()))
        plt.title("Analysis Results")
        plt.tight_layout()

        return fig, results

    except Exception as e:
        return None, {"error": str(e)}
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
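

# Gradio UI: three tabs that wire the helper functions above to image inputs and plot/JSON outputs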
with gr.Blocks(title="Face Recognition Toolkit", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Face Recognition Toolkit")

    with gr.Tabs():
        with gr.Tab("Verify Faces"):
            with gr.Row():
                img1 = gr.Image(label="First Image", type="pil")
                img2 = gr.Image(label="Second Image", type="pil")
            verify_threshold = gr.Slider(0.1, 1.0, 0.6, label="Match Threshold")
            verify_model = gr.Dropdown(["VGG-Face", "Facenet", "OpenFace"], value="VGG-Face")
            verify_btn = gr.Button("Verify Faces")
            verify_output = gr.Plot()
            verify_json = gr.JSON()

            verify_btn.click(
                verify_faces,
                [img1, img2, verify_threshold, verify_model],
                [verify_output, verify_json]
            )

        with gr.Tab("Find Faces"):
            query_img = gr.Image(label="Query Image", type="pil")
            db_input = gr.Textbox("/content/drive/MyDrive/db", label="Database Path")
            db_files = gr.File(file_count="multiple", label="Or Upload Images")
            find_threshold = gr.Slider(0.1, 1.0, 0.6, label="Similarity Threshold")
            find_model = gr.Dropdown(["VGG-Face", "Facenet", "OpenFace"], value="VGG-Face")
            find_btn = gr.Button("Find Matches")
            find_output = gr.Plot()
            find_json = gr.JSON()

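            # Prefer uploaded images over the Drive path when both are provided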
            find_btn.click(
                lambda q, path, files, thr, mdl: find_faces(q, files if files else path, thr, mdl),
                [query_img, db_input, db_files, find_threshold, find_model],
                [find_output, find_json]
            )
            # Clear the Drive path field once files are uploaded
            db_files.change(lambda _: "", db_files, db_input)

        with gr.Tab("Analyze Face"):
            analyze_img = gr.Image(label="Input Image", type="pil")
            analyze_actions = gr.CheckboxGroup(
                ["age", "gender", "emotion", "race"],
                value=["age", "gender", "emotion"],
                label="Analysis Features"
            )
            analyze_btn = gr.Button("Analyze")
            analyze_output = gr.Plot()
            analyze_json = gr.JSON()

            analyze_btn.click(
                analyze_face,
                [analyze_img, analyze_actions],
                [analyze_output, analyze_json]
            )


demo.launch()