3v324v23 commited on
Commit
8164b8a
·
1 Parent(s): a09a133

visualization update

Browse files
Files changed (3) hide show
  1. app.py +16 -4
  2. mmdet/core/visualization/image.py +4 -5
  3. requirements.txt +9 -0
app.py CHANGED
@@ -3,6 +3,13 @@ import numpy as np
3
  import gradio as gr
4
  from infer import detections
5
 
 
 
 
 
 
 
 
6
  def walt_demo(input_img):
7
  #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
8
  detect = detections('configs/walt/walt_vehicle.py', 'cuda:0', model_path='data/models/walt_vehicle.pth')
@@ -34,15 +41,19 @@ article="""
34
 
35
  examples = [
36
  'demo/images/img_1.jpg',
 
 
 
37
  ]
38
 
39
-
40
  import cv2
41
  filename='demo/images/img_1.jpg'
42
  img=cv2.imread(filename)
43
  img=walt_demo(img)
44
- cv2.imwrite(filename.replace('demo','demo/results/'),img)
45
-
 
46
  demo = gr.Interface(walt_demo,
47
  gr.Image(),
48
  "image",
@@ -52,6 +63,7 @@ demo = gr.Interface(walt_demo,
52
  examples=examples,
53
  description=description)
54
 
55
- demo.launch(server_name="0.0.0.0", server_port=7000)
 
56
 
57
 
 
3
  import gradio as gr
4
  from infer import detections
5
 
6
+
7
+ import os
8
+ os.system("mkdir data")
9
+ os.system("mkdir data/models")
10
+ os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
11
+ os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
12
+
13
  def walt_demo(input_img):
14
  #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
15
  detect = detections('configs/walt/walt_vehicle.py', 'cuda:0', model_path='data/models/walt_vehicle.pth')
 
41
 
42
  examples = [
43
  'demo/images/img_1.jpg',
44
+ 'demo/images/img_2.jpg',
45
+ 'demo/images/img_3.png',
46
+ 'demo/images/img_4.png',
47
  ]
48
 
49
+ '''
50
  import cv2
51
  filename='demo/images/img_1.jpg'
52
  img=cv2.imread(filename)
53
  img=walt_demo(img)
54
+ cv2.imwrite(filename.replace('/images/','/results/'),img)
55
+ cv2.imwrite('check.png',img)
56
+ '''
57
  demo = gr.Interface(walt_demo,
58
  gr.Image(),
59
  "image",
 
63
  examples=examples,
64
  description=description)
65
 
66
+ #demo.launch(server_name="0.0.0.0", server_port=7000)
67
+ demo.launch()
68
 
69
 
mmdet/core/visualization/image.py CHANGED
@@ -165,13 +165,12 @@ def imshow_det_bboxes(img,
165
  color_mask = mask_colors[np.random.randint(0, 99)]
166
  mask = segms[len(labels)*ll+i].astype(bool)
167
  show_border = True
168
- #img[mask] = img[mask] * 0.5 + color_mask * 0.5
169
  if show_border:
170
  contours,_ = cv2.findContours(mask.copy().astype('uint8'), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
171
- border_thick = min(int(4*(max(bbox_int[2]-bbox_int[0],bbox_int[3]-bbox_int[1])/200))+1,6)
172
- print(border_thick, bbox_int)
173
- cv2.drawContours(img_bound, contours, -1, (int(color_mask[0][0]),int(color_mask[0][1]),int(color_mask[0][2])), border_thick)
174
- img = cv2.addWeighted(img,1.0,img_bound,0.6,0)
175
 
176
  #img[img_bound>0] = img_bound
177
 
 
165
  color_mask = mask_colors[np.random.randint(0, 99)]
166
  mask = segms[len(labels)*ll+i].astype(bool)
167
  show_border = True
168
+ img[mask] = img[mask] * 0.5 + color_mask * 0.5
169
  if show_border:
170
  contours,_ = cv2.findContours(mask.copy().astype('uint8'), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
171
+ border_thick = min(int(4*(max(bbox_int[2]-bbox_int[0],bbox_int[3]-bbox_int[1])/300))+1,6)
172
+ cv2.drawContours(img, contours, -1, (int(color_mask[0][0]),int(color_mask[0][1]),int(color_mask[0][2])), border_thick)
173
+ #img = cv2.addWeighted(img,1.0,img_bound,1.0,0)
 
174
 
175
  #img[img_bound>0] = img_bound
176
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.9.0
2
+ mmcv-full==1.4.0
3
+ git+https://github.com/open-mmlab/mmdetection.git@7bd39044f35aec4b90dd797b965777541a8678ff
4
+ gradio==3.0.20
5
+ timm
6
+ scikit-image
7
+ imagesize
8
+ torchvision==0.10.0
9
+ imantics