import os
import cv2
import time
import torch
import numpy as np
import requests
import gradio as gr
import wikipedia
from io import BytesIO
from PIL import Image, ImageDraw
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import diffusers
from ultralytics import YOLO
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from huggingface_hub import hf_hub_download

# Determine the compute device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# ---------------------------
# Load models from Hugging Face
# ---------------------------
# YOLO detection model: load the .pt weights from a Hugging Face repo.
# Swap in your own repo id and filename below if you use different weights.
yolo_weights_path = hf_hub_download(repo_id="FathomNet/MBARI-315k-yolov8", filename="mbari_315k_yolov8.pt")
yolo_model = YOLO(yolo_weights_path)
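
# Quick sanity check: the checkpoint bundles its own class map, so printing
# its size also confirms the weights downloaded and parsed correctly.
print(f"[Eurybia] YOLO model loaded with {len(yolo_model.names)} classes")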

# QA pipeline (for "Ask Eurybia")
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")

# Gemma model (optional, left commented out; it also loads from Hugging Face):
# gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
# gemma_model = AutoModelForCausalLM.from_pretrained(
#     "google/gemma-2-2b-it", device_map="auto",
#     torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32)

# Depth estimation model (Marigold, via Diffusers); use the fp16 variant on CUDA
if device == "cuda":
    depth_pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
        "prs-eth/marigold-depth-lcm-v1-0",
        variant="fp16",
        torch_dtype=torch.float16,
    ).to(device)
else:
    depth_pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
        "prs-eth/marigold-depth-lcm-v1-0"
    ).to(device)

# RealESRGAN upscaling model: download the weights from Hugging Face.
# (Ensure that the repo_id and filename point to a valid model on Hugging Face.)
upscaler_weight_path = hf_hub_download(repo_id="RealESRGAN/RealESRGAN_x4plus", filename="RealESRGAN_x4plus.pth")
model_rrdb = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                     num_block=23, num_grow_ch=32, scale=4)
upscaler = RealESRGANer(
    scale=4,
    model_path=upscaler_weight_path,
    model=model_rrdb,
    pre_pad=0,
    half=(device == "cuda"),
    device=device,
)
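
# Note: on memory-constrained hardware, RealESRGANer also accepts a `tile`
# argument (e.g. tile=256) to upscale in patches instead of one full pass;
# a possible tweak rather than something this app requires.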

# ---------------------------
# Define functional endpoints
# ---------------------------
def detect_objects(input_image):
    """
    Runs YOLO detection on an input image, draws bounding boxes, and returns
    both the processed image and detection info.
    """
    # Convert PIL to a NumPy array, then to BGR (OpenCV's channel order)
    image_np = np.array(input_image)
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    # Run detection with a low confidence threshold (0.075) to favor recall
    results = yolo_model.predict(source=image_bgr, conf=0.075)[0]
    # Draw on a copy so the original array stays untouched
    image_out = image_bgr.copy()
    detection_info = ""
    # results.boxes can be non-None yet empty, so check its length too
    if results.boxes is not None and len(results.boxes) > 0:
        for box in results.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            class_name = yolo_model.names[int(box.cls)]
            confidence = box.conf.item() * 100
            detection_info += f"{class_name}: {confidence:.2f}%\n"
            cv2.rectangle(image_out, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cv2.putText(image_out, f"{class_name} {confidence:.2f}%",
                        (x1, max(y1 - 10, 0)), cv2.FONT_HERSHEY_SIMPLEX,
                        0.9, (0, 0, 255), 2)
    else:
        detection_info = "No detections found."
    # Convert back to RGB for display in Gradio
    image_out_rgb = cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB)
    output_image = Image.fromarray(image_out_rgb)
    return output_image, detection_info
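
# Example usage outside the UI (hypothetical local test; "sample.jpg" is a
# placeholder path, not a file shipped with this Space):
#   img = Image.open("sample.jpg")
#   out_img, info = detect_objects(img)
#   print(info)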

def get_object_info(class_name):
    """
    Looks up the given class name on Wikipedia and returns a short description
    and an image (if one is found).
    """
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    try:
        page = wikipedia.page(class_name)
        description = page.content[:5000]
        img_url = None
        for img in page.images:
            if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                img_url = img
                break
        if img_url:
            # Time out rather than hang the endpoint on a slow image host
            response = requests.get(img_url, timeout=10)
            info_image = Image.open(BytesIO(response.content))
        else:
            info_image = None
    except Exception as e:
        description = f"Error fetching info: {e}"
        info_image = None
    return description, info_image
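
# Sketch, assuming repeated lookups of the same class are common: the argument
# is a plain string, so the whole function can be memoized to avoid re-fetching
# identical Wikipedia pages.
#   from functools import lru_cache
#   get_object_info = lru_cache(maxsize=64)(get_object_info)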

def ask_eurybia(context, question):
    """
    Uses the QA pipeline to answer a question given a context.
    """
    try:
        answer = qa_pipeline(question=question, context=context)
        if not answer['answer'].strip():
            return "Unknown"
        return answer['answer']
    except Exception as e:
        return f"Error: {e}"

def enhance_image(input_image):
    """
    Enhances (upscales) the input image using the RealESRGAN model.
    """
    try:
        # Ensure the image is in RGB
        img_np = np.array(input_image.convert("RGB"))
        output, _ = upscaler.enhance(img_np, outscale=4)
        enhanced_image = Image.fromarray(output)
        return enhanced_image
    except Exception as e:
        # Raise a Gradio error instead of returning a string: the output
        # component is an Image and cannot render text.
        raise gr.Error(f"Error during enhancement: {e}")

def predict_depth(input_image):
    """
    Predicts a depth map from the input image using the Diffusers pipeline.
    """
    try:
        image_rgb = input_image.convert("RGB")
        result = depth_pipe(image_rgb)
        depth_prediction = result.prediction
        vis_depth = depth_pipe.image_processor.visualize_depth(depth_prediction)
        # visualize_depth returns a list; take the first image as the output
        depth_img = vis_depth[0]
        return depth_img
    except Exception as e:
        # If an error occurs, return a blank image carrying the error message.
        img = Image.new("RGB", (400, 300), color=(255, 255, 255))
        draw = ImageDraw.Draw(img)
        draw.text((10, 150), f"Error: {e}", fill=(255, 0, 0))
        return img
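
# Sketch, assuming the standard Marigold pipeline call signature: the LCM
# variant is fast at its defaults, but speed and quality can be traded off,
# e.g.
#   result = depth_pipe(image_rgb, num_inference_steps=4, ensemble_size=1)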
| # --------------------------- | |
| # Build the Gradio Interface | |
| # --------------------------- | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Eurybia Mini") | |
| gr.Markdown("This Gradio app replicates the functionalities of your original Tkinter app. " | |
| "The YOLO and upscaling model weights are now loaded from Hugging Face.") | |
| with gr.Tabs(): | |
| with gr.Tab("Object Detection"): | |
| gr.Markdown("Upload an image for object detection.") | |
| with gr.Row(): | |
| input_image = gr.Image(label="Input Image", source="upload", type="pil") | |
| output_image = gr.Image(label="Detected Image") | |
| detection_text = gr.Textbox(label="Detection Info") | |
| btn_detect = gr.Button("Detect") | |
| btn_detect.click(detect_objects, inputs=input_image, outputs=[output_image, detection_text]) | |
| with gr.Tab("Object Info"): | |
| gr.Markdown("Enter a class name to fetch info from Wikipedia.") | |
| class_input = gr.Textbox(label="Class Name") | |
| info_text = gr.Textbox(label="Description") | |
| info_image = gr.Image(label="Info Image") | |
| btn_info = gr.Button("Get Info") | |
| btn_info.click(get_object_info, inputs=class_input, outputs=[info_text, info_image]) | |
| with gr.Tab("Ask Eurybia"): | |
| gr.Markdown("Provide a context and ask a question.") | |
| context_input = gr.Textbox(label="Context", lines=10, placeholder="Paste context (e.g., detection info) here...") | |
| question_input = gr.Textbox(label="Question", placeholder="Enter your question here...") | |
| answer_output = gr.Textbox(label="Answer") | |
| btn_ask = gr.Button("Ask") | |
| btn_ask.click(ask_eurybia, inputs=[context_input, question_input], outputs=answer_output) | |
| with gr.Tab("Enhance Image"): | |
| gr.Markdown("Upload an image to enhance (upscale).") | |
| enhance_input = gr.Image(label="Input Image", source="upload", type="pil") | |
| enhanced_output = gr.Image(label="Enhanced Image") | |
| btn_enhance = gr.Button("Enhance") | |
| btn_enhance.click(enhance_image, inputs=enhance_input, outputs=enhanced_output) | |
| with gr.Tab("Depth Prediction"): | |
| gr.Markdown("Upload an image for depth prediction.") | |
| depth_input = gr.Image(label="Input Image", source="upload", type="pil") | |
| depth_output = gr.Image(label="Depth Image") | |
| btn_depth = gr.Button("Predict Depth") | |
| btn_depth.click(predict_depth, inputs=depth_input, outputs=depth_output) | |
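
# For a public Space under load, enabling Gradio's request queue before launch
# keeps long-running jobs like upscaling from hitting request timeouts
# (a sketch, assuming the default queue settings suffice):
#   demo.queue()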
demo.launch()