akhaliq HF Staff commited on
Commit
5736f4d
·
1 Parent(s): aa46e5f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnx
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ from PIL import Image
5
+ import cv2
6
+ import os
7
+ import gradio as gr
8
+
9
+ import mxnet
10
+ from mxnet.gluon.data.vision import transforms
11
+
12
+ os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
13
+
14
+
15
+ with open('synset.txt', 'r') as f:
16
+ labels = [l.rstrip() for l in f]
17
+
18
+ os.system("wget https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx")
19
+
20
+ os.system("wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg")
21
+
22
+
23
+
24
+ model_path = 'shufflenet-9.onnx'
25
+ model = onnx.load(model_path)
26
+ session = ort.InferenceSession(model.SerializeToString())
27
+
28
+ def get_image(path):
29
+ with Image.open(path) as img:
30
+ img = np.array(img.convert('RGB'))
31
+ return img
32
+
33
+ def preprocess(img):
34
+ '''
35
+ Preprocessing required on the images for inference with mxnet gluon
36
+ The function takes path to an image and returns processed tensor
37
+ '''
38
+ transform_fn = transforms.Compose([
39
+ transforms.Resize(224),
40
+ transforms.CenterCrop(224),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
43
+ ])
44
+ img = mxnet.ndarray.array(img)
45
+ img = transform_fn(img)
46
+ img = img.expand_dims(axis=0) # batchify
47
+
48
+ return img.asnumpy()
49
+
50
+
51
+ def predict(path):
52
+ img = get_image(path)
53
+ img = preprocess(img)
54
+ ort_inputs = {session.get_inputs()[0].name: img}
55
+ preds = session.run(None, ort_inputs)[0]
56
+ preds = np.squeeze(preds)
57
+ a = np.argsort(preds)
58
+ results = {}
59
+ for i in a[0:5]:
60
+ results[labels[a[i]]] = float(preds[a[i]])
61
+ return results
62
+
63
+
64
+ title="ShuffleNet-v1"
65
+ description="ShuffleNet is a deep convolutional network for image classification. ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification."
66
+
67
+ examples=[['kitten.jpg']]
68
+ gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)