sanpdy commited on
Commit
f9f1c14
·
1 Parent(s): 15ab5d3

Using huggingface-hosted models

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -9,6 +9,8 @@ import gradio as gr
9
  from ultralytics import YOLO
10
  from transformers import ResNetModel
11
  import cv2
 
 
12
 
13
  class FlakeLayerClassifier(nn.Module):
14
  def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
@@ -71,14 +73,18 @@ print(f"Using device: {device}")
71
  #yolo = YOLO("/home/sankalp/flake_classification/models/best.pt")
72
  #yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt")
73
  #yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
74
- yolo = YOLO("models/yolo-flake-detector.pt")
 
75
  yolo.conf = 0.5
76
 
77
  # Load classifier weights
78
- ckpt = torch.load(
79
- "/home/sankalp/flake_classification/models/flake_classifier.pth",
80
- map_location=device
 
81
  )
 
 
82
  num_classes = len(ckpt["class_to_idx"])
83
  classifier = FlakeLayerClassifier(
84
  num_materials=num_classes,
 
9
  from ultralytics import YOLO
10
  from transformers import ResNetModel
11
  import cv2
12
+ from huggingface_hub import hf_hub_download
13
+
14
 
15
  class FlakeLayerClassifier(nn.Module):
16
  def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
 
73
  #yolo = YOLO("/home/sankalp/flake_classification/models/best.pt")
74
  #yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt")
75
  #yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
76
+ yolo_path = hf_hub_download(repo_id="sanpdy/yolo-flake-detector", filename="yolo-flake-detector-MSU.pt", token=False)
77
+ yolo = YOLO(yolo_path)
78
  yolo.conf = 0.5
79
 
80
  # Load classifier weights
81
+ classifier_path = hf_hub_download(
82
+ repo_id="sanpdy/flake-classifier",
83
+ filename="flake-classifier.pth",
84
+ token=False
85
  )
86
+ ckpt = torch.load(classifier_path, map_location=device)
87
+
88
  num_classes = len(ckpt["class_to_idx"])
89
  classifier = FlakeLayerClassifier(
90
  num_materials=num_classes,