# app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from ultralytics import YOLO
from transformers import ResNetModel
import cv2
from huggingface_hub import hf_hub_download


class FlakeLayerClassifier(nn.Module):
    """Layer-count classifier built on a ResNet-18 backbone.

    Two heads share the backbone features:

    * ``fc_img`` scores the image features alone.
    * ``fc_comb`` scores the image features concatenated with a learned
      material embedding.

    Submodule names and the positions inside each ``nn.Sequential`` determine
    the state-dict keys, so they must not be reordered or renamed if existing
    checkpoints are to keep loading.

    Args:
        num_materials: Number of rows in the material embedding table.
        material_dim: Width of each material embedding vector.
        num_classes: Number of output logits produced by either head.
        dropout_prob: Dropout probability used inside both heads.
        freeze_cnn: When true, backbone parameters get ``requires_grad=False``.
    """

    def __init__(
        self,
        num_materials: int,
        material_dim: int,
        num_classes: int = 4,
        dropout_prob: float = 0.1,
        freeze_cnn: bool = False,
    ):
        super().__init__()
        self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
        if freeze_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

        # Width of the last backbone stage, as reported by the model config.
        img_feat_dim = self.cnn.config.hidden_sizes[-1]

        self.material_embedding = nn.Embedding(num_materials, material_dim)
        # A single Dropout module is deliberately reused by both heads.
        self.dropout = nn.Dropout(dropout_prob)

        self.fc_img = nn.Sequential(
            nn.Linear(img_feat_dim, img_feat_dim),
            nn.ReLU(inplace=True),
            self.dropout,
            nn.Linear(img_feat_dim, num_classes),
        )

        combined_dim = img_feat_dim + material_dim
        self.fc_comb = nn.Sequential(
            nn.Linear(combined_dim, combined_dim),
            nn.ReLU(inplace=True),
            self.dropout,
            nn.Linear(combined_dim, num_classes),
        )

    def forward(self, pixel_values: torch.Tensor, material=None) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            pixel_values: Image batch forwarded to the backbone as
                ``pixel_values``.
            material: Optional index tensor looked up in the material
                embedding. When ``None`` the image-only head is used.

        Returns:
            Logits from ``fc_img`` (no material) or ``fc_comb`` (with material).
        """
        outputs = self.cnn(pixel_values=pixel_values)
        pooled = outputs.pooler_output
        # Flatten everything after the batch dimension into one feature vector.
        img_feats = pooled.view(pooled.size(0), -1)

        if material is None:
            return self.fc_img(img_feats)

        mat_emb = self.material_embedding(material)
        combined = torch.cat([img_feats, mat_emb], dim=1)
        return self.fc_comb(combined)


def calibration(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Match the per-channel LAB statistics of ``target_img`` to ``source_img``.

    Each LAB channel of the target is shifted and scaled so that its mean and
    standard deviation equal those of the corresponding source channel. The
    inputs are not modified (``cv2.cvtColor`` returns new arrays).

    Args:
        source_img: Reference image supplying the desired colour statistics.
            Converted with ``COLOR_BGR2LAB``, so it is interpreted as BGR;
            assumed to be 8-bit, 3-channel (TODO confirm with callers).
        target_img: Image to be corrected, same interpretation as above.

    Returns:
        The corrected target image, converted back to BGR, as ``uint8``.
    """
    source_lab = cv2.cvtColor(source_img, cv2.COLOR_BGR2LAB)
    target_lab = cv2.cvtColor(target_img, cv2.COLOR_BGR2LAB)

    for i in range(3):
        src_mean, src_std = (v.item() for v in cv2.meanStdDev(source_lab[:, :, i]))
        tgt_mean, tgt_std = (v.item() for v in cv2.meanStdDev(target_lab[:, :, i]))

        # A perfectly flat target channel has zero spread; dividing by it would
        # yield inf/NaN. Every centred pixel is 0 in that case, so any finite
        # scale maps the channel onto the source mean.
        scale = src_std / tgt_std if tgt_std > 0 else 0.0

        centred = target_lab[:, :, i].astype(np.float64) - tgt_mean
        target_lab[:, :, i] = (centred * scale + src_mean).clip(0, 255)

    corrected_img = cv2.cvtColor(target_lab, cv2.COLOR_LAB2BGR)
    return corrected_img.astype(np.uint8)


# Prefer the second GPU when the host has one (the original deployment
# target), but fall back to the first GPU on single-GPU machines, where
# "cuda:1" would be an invalid device ordinal.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load YOLO detector
yolo_path = hf_hub_download(
    repo_id="sanpdy/yolo-flake-detector",
    filename="yolo-flake-detector-MSU.pt",
    token=False,
)
yolo = YOLO(yolo_path)

# Minimum detection confidence. Ultralytics takes this threshold from the
# arguments of the predict call, so it is passed explicitly in
# detect_and_classify(); the attribute is kept only for backward
# compatibility with code that may read ``yolo.conf``.
DETECTION_CONF = 0.5
yolo.conf = DETECTION_CONF

# Load classifier weights
classifier_path = hf_hub_download(
    repo_id="sanpdy/flake-classifier",
    filename="flake_classifier_5layer.pth",
    token=False,
)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted repository.
ckpt = torch.load(classifier_path, map_location=device)
num_classes = len(ckpt["class_to_idx"])
classifier = FlakeLayerClassifier(
    num_materials=num_classes,
    material_dim=64,
    num_classes=num_classes,
    dropout_prob=0.1,
    freeze_cnn=False,
).to(device)
classifier.load_state_dict(ckpt["model_state_dict"])
classifier.eval()

# Image processing transforms
clf_tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

try:
    FONT = ImageFont.truetype("arial.ttf", 20)
except IOError:
    # Font file not available on this host; use PIL's built-in bitmap font.
    FONT = ImageFont.load_default()


# Inference + drawing
def detect_and_classify(image: Image.Image):
    """Detect flakes with YOLO, classify each crop, and draw the results.

    Args:
        image: Input picture. ``None`` is accepted and returned as ``None``
            so an empty submission does not raise.

    Returns:
        A new RGB ``PIL.Image`` with a red box and a
        ``"Layer <n> (<prob>)"`` label per detection. The input image is
        left untouched.
    """
    if image is None:
        return None

    # NOTE: an optional colour-calibration step against a template image used
    # to run here via calibration(). It is disabled; if re-enabled, remember
    # that calibration() converts with COLOR_BGR2LAB while arrays built from a
    # PIL "RGB" image are in RGB order.

    # Work on a 3-channel copy: the classifier's Normalize expects three
    # channels, and the caller's image must not be modified.
    rgb_image = image.convert("RGB")

    # Reverse the channel axis to get BGR; the reversed slice is a
    # negative-stride view, so materialise it as a contiguous array.
    img_bgr = np.ascontiguousarray(np.array(rgb_image)[:, :, ::-1])

    results = yolo(img_bgr, conf=DETECTION_CONF, device=str(device))
    boxes = results[0].boxes.xyxy.cpu().numpy()

    # Draw on a separate copy so that crops taken later in the loop never
    # contain boxes or labels drawn for earlier detections.
    annotated = rgb_image.copy()
    draw = ImageDraw.Draw(annotated)

    for box in boxes:
        x1, y1, x2, y2 = (int(round(float(v))) for v in box)
        if x2 <= x1 or y2 <= y1:
            # Zero-area box after rounding: nothing to crop or resize.
            continue

        crop = rgb_image.crop((x1, y1, x2, y2))
        inp = clf_tf(crop).unsqueeze(0).to(device)  # (1,C,H,W)

        with torch.no_grad():
            logits = classifier(pixel_values=inp)
            pred = logits.argmax(1).item()
            prob = F.softmax(logits, dim=1)[0, pred].item()

        label = f"Layer {pred+1} ({prob:.2f})"

        # draw
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        # Put the label just above the box, clamped to the top edge.
        draw.text((x1, max(0, y1-18)), label, fill="red", font=FONT)

    return annotated


# Gradio UI
demo = gr.Interface(
    fn=detect_and_classify,
    inputs=gr.Image(type="pil", label="Upload Flake Image"),
    outputs=gr.Image(type="pil", label="Annotated Output"),
    title="Flake Detection + Layer Classification",
    description="Upload an image → YOLO finds flakes → ResNet-18 head classifies their layer.",
)

if __name__ == "__main__":
    demo.launch(share=True)