heboya8 commited on
Commit
a08bf4a
·
verified ·
1 Parent(s): 766b1e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -2
app.py CHANGED
@@ -1,10 +1,62 @@
1
- from flask import Flask, render_template
 
 
 
 
 
2
 
3
  app = Flask(__name__)
4
 
 
 
 
 
 
 
 
 
 
5
  @app.route('/')
6
- def home():
7
  return render_template('index.html')
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  if __name__ == '__main__':
10
  app.run(host='0.0.0.0', port=7860)
 
1
+ from flask import Flask, render_template, request, redirect, url_for
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+ import os
6
+ import uuid
7
 
8
  app = Flask(__name__)
9
 
10
+ # Set upload folder
11
+ UPLOAD_FOLDER = 'static/uploads'
12
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
13
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
14
+
15
+ # Load DETR model and processor
16
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
17
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
18
+
19
  @app.route('/')
20
+ def index():
21
  return render_template('index.html')
22
 
23
+ @app.route('/upload', methods=['POST'])
24
+ def upload_file():
25
+ if 'file' not in request.files:
26
+ return redirect(request.url)
27
+ file = request.files['file']
28
+ if file.filename == '':
29
+ return redirect(request.url)
30
+
31
+ # Save the uploaded file
32
+ filename = str(uuid.uuid4()) + os.path.splitext(file.filename)[1]
33
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
34
+ file.save(filepath)
35
+
36
+ # Process image for object detection
37
+ image = Image.open(filepath).convert("RGB")
38
+ inputs = processor(images=image, return_tensors="pt")
39
+ outputs = model(**inputs)
40
+
41
+ # Post-process outputs
42
+ target_sizes = torch.tensor([image.size[::-1]])
43
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
44
+
45
+ # Draw bounding boxes
46
+ draw = ImageDraw.Draw(image)
47
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
48
+ box = [round(i, 2) for i in box.tolist()]
49
+ label_str = model.config.id2label[label.item()]
50
+ draw.rectangle(box, outline="red", width=3)
51
+ draw.text((box[0], box[1]), f"{label_str}: {score:.2f}", fill="red")
52
+
53
+ # Save output image
54
+ output_filename = f"output_{filename}"
55
+ output_filepath = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
56
+ image.save(output_filepath)
57
+
58
+ return render_template('results.html', original_image=url_for('static', filename=f'uploads/{filename}'),
59
+ processed_image=url_for('static', filename=f'uploads/{output_filename}'))
60
+
61
  if __name__ == '__main__':
62
  app.run(host='0.0.0.0', port=7860)