# Spaces:
# Paused
# Paused
from flask import Flask, render_template, request, redirect, url_for | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
from PIL import Image, ImageDraw | |
import torch | |
import os | |
import uuid | |
import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Set upload folder.
# Uploaded and processed images are later exposed via
# url_for('static', filename='uploads/...'), which resolves against the app's
# static folder. Deriving the path from app.static_folder (built from the app's
# root path) instead of the CWD-relative 'static/uploads' keeps the files
# servable even when the process is launched from another working directory.
UPLOAD_FOLDER = os.path.join(app.static_folder, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Load the DETR model and processor once at import time; request handlers
# reuse these module-level objects instead of reloading per request.
logger.info("Loading DETR model and processor...")
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
logger.info("Model and processor loaded successfully.")
@app.route('/')
def index():
    """Render the landing page template ``index.html``.

    Without the route decorator this view was never registered with the app,
    so requests to ``/`` returned 404.
    """
    return render_template('index.html')
# Upload extensions accepted by upload_file (compared lowercase). Restricting
# them matters because the stored file is served back from the static folder
# (so e.g. ".html" must never be written there), and because PIL picks the
# output format for image.save() from this extension.
ALLOWED_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'}


def _annotate_detections(image, threshold=0.9):
    """Run DETR on ``image`` and draw labelled red boxes onto it in place.

    Args:
        image: An RGB ``PIL.Image.Image``; it is modified in place.
        threshold: Minimum detection score for a box to be drawn.

    Returns:
        The same ``image`` object.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: without no_grad, autograd records the whole forward
    # pass, wasting memory and time on every request.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); post-processing takes (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    draw = ImageDraw.Draw(image)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(coord, 2) for coord in box.tolist()]
        label_str = model.config.id2label[label.item()]
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"{label_str}: {score.item():.2f}", fill="red")
    return image


# NOTE(review): the path is assumed; it must match the form action in
# templates/index.html, which is not visible here.
@app.route('/upload', methods=['POST'])
def upload_file():
    """Handle an image upload: run object detection and show the result.

    Reads the ``file`` field of the multipart form. On success it renders
    ``results.html`` with static URLs for the uploaded original and for the
    annotated copy. On an unsupported file type or a processing error it
    re-renders ``index.html`` with an ``error`` message. Requests without a
    file redirect to the index page.
    """
    if 'file' not in request.files:
        logger.warning("No file part in request.")
        return redirect(url_for('index'))
    file = request.files['file']
    # filename can be None as well as '' when the client sends no name.
    if not file.filename:
        logger.warning("No file selected.")
        return redirect(url_for('index'))

    # Keep only a validated, lowercased extension from the client's name; the
    # stem is replaced by a UUID, so the stored path is never client-controlled.
    extension = os.path.splitext(file.filename)[1].lower()
    if extension not in ALLOWED_EXTENSIONS:
        logger.warning(f"Rejected upload with unsupported extension: {extension!r}")
        return render_template('index.html',
                               error=f"Unsupported file type: {extension or '(none)'}")

    filename = str(uuid.uuid4()) + extension
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    output_filename = f"output_{filename}"
    output_filepath = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
    try:
        # Save uploaded file
        file.save(filepath)
        logger.info(f"File saved: {filename}")

        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy that stays usable afterwards.
        with Image.open(filepath) as source:
            image = source.convert("RGB")
        # Shrink for performance while preserving aspect ratio (a fixed
        # 800x600 resize stretched non-4:3 images).
        image.thumbnail((800, 800))

        _annotate_detections(image)

        # Save output image
        image.save(output_filepath)
        logger.info(f"Processed image saved: {output_filename}")
        return render_template('results.html',
                               original_image=url_for('static', filename=f'uploads/{filename}'),
                               processed_image=url_for('static', filename=f'uploads/{output_filename}'))
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Error processing file: {str(e)}")
        # Best-effort cleanup so failed uploads don't accumulate in the public
        # folder; a file that was never written is simply skipped.
        for path in (filepath, output_filepath):
            try:
                os.remove(path)
            except OSError:
                pass
        return render_template('index.html', error=f"Error processing file: {str(e)}")
if __name__ == '__main__':
    # Start Flask's built-in server bound to every network interface
    # (not only localhost), listening on port 7860.
    app.run(port=7860, host='0.0.0.0')