Spaces:
Running
Running
Using huggingface-hosted models
Browse files
app.py
CHANGED
@@ -9,6 +9,8 @@ import gradio as gr
|
|
9 |
from ultralytics import YOLO
|
10 |
from transformers import ResNetModel
|
11 |
import cv2
|
|
|
|
|
12 |
|
13 |
class FlakeLayerClassifier(nn.Module):
|
14 |
def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
|
@@ -71,14 +73,18 @@ print(f"Using device: {device}")
|
|
71 |
#yolo = YOLO("/home/sankalp/flake_classification/models/best.pt")
|
72 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt")
|
73 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
|
74 |
-
|
|
|
75 |
yolo.conf = 0.5
|
76 |
|
77 |
# Load classifier weights
|
78 |
-
|
79 |
-
"/
|
80 |
-
|
|
|
81 |
)
|
|
|
|
|
82 |
num_classes = len(ckpt["class_to_idx"])
|
83 |
classifier = FlakeLayerClassifier(
|
84 |
num_materials=num_classes,
|
|
|
9 |
from ultralytics import YOLO
|
10 |
from transformers import ResNetModel
|
11 |
import cv2
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
|
14 |
|
15 |
class FlakeLayerClassifier(nn.Module):
|
16 |
def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
|
|
|
73 |
#yolo = YOLO("/home/sankalp/flake_classification/models/best.pt")
|
74 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt")
|
75 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
|
76 |
+
yolo_path = hf_hub_download(repo_id="sanpdy/yolo-flake-detector", filename="yolo-flake-detector-MSU.pt", token=False)
|
77 |
+
yolo = YOLO(yolo_path)
|
78 |
yolo.conf = 0.5
|
79 |
|
80 |
# Load classifier weights
|
81 |
+
classifier_path = hf_hub_download(
|
82 |
+
repo_id="sanpdy/flake-classifier",
|
83 |
+
filename="flake-classifier.pth",
|
84 |
+
token=False
|
85 |
)
|
86 |
+
ckpt = torch.load(classifier_path, map_location=device)
|
87 |
+
|
88 |
num_classes = len(ckpt["class_to_idx"])
|
89 |
classifier = FlakeLayerClassifier(
|
90 |
num_materials=num_classes,
|