|
import json
import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from datetime import datetime

import gradio as gr
import spacy
import torch
from PIL import Image
from spacy.cli import download
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
OUTPUT_JSON_PATH = "captions.json" |
|
UPLOAD_DIR = "uploads" |
|
os.makedirs(UPLOAD_DIR, exist_ok=True) |
|
|
|
|
|
def _load_spacy_pipeline(model_name):
    """Load a spaCy pipeline by name, downloading it once if it is not installed.

    Any exception from the download or the second load propagates to the caller.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # spacy.load raises OSError when the model package is missing.
        logger.info("Downloading en_core_web_sm model...")
        download(model_name)
        return spacy.load(model_name)


try:
    nlp = _load_spacy_pipeline("en_core_web_sm")
except Exception as exc:
    logger.error(f"Error loading SpaCy model: {str(exc)}")
    raise
|
|
|
|
|
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"


def _select_model_dtype():
    """Return the weight dtype for the current hardware.

    Half precision is only worth using on a GPU. On CPU, many kernels lack
    float16 support or are very slow in it, so use float32 there.
    """
    if torch.cuda.is_available():
        # The JoyCaption model card loads the weights in bfloat16. Fall back to
        # float16 on GPUs without bf16 support.
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


try:
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=_select_model_dtype(),
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    # Inference only: disable dropout and other training-mode behaviour.
    model.eval()
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise
|
|
|
|
|
def extract_keywords(text, max_keywords=5, nlp_model=None):
    """Extract up to ``max_keywords`` distinct noun/adjective keywords from ``text``.

    Args:
        text: Text to analyse, typically a generated caption.
        max_keywords: Maximum number of keywords to return. The default of 5
            matches the previous hard-coded limit.
        nlp_model: Optional callable mapping text to an iterable of spaCy-like
            tokens (with ``.text``, ``.pos_`` and ``.is_stop``). Defaults to
            the module-level ``nlp`` pipeline.

    Returns:
        Lower-cased, non-stop-word NOUN/ADJ tokens in order of first
        appearance, or ``[]`` if analysis fails for any reason.
    """
    try:
        pipeline = nlp if nlp_model is None else nlp_model
        doc = pipeline(text)
        keywords = [
            token.text.lower()
            for token in doc
            if token.pos_ in ("NOUN", "ADJ") and not token.is_stop
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order. A set's
        # iteration order depends on per-process string hashing, so
        # set-then-slice would keep different keywords from run to run.
        return list(dict.fromkeys(keywords))[:max_keywords]
    except Exception as e:
        logger.error(f"Error extracting keywords: {str(e)}")
        return []
|
|
|
|
|
def save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None):
    """Append one captioning record to the JSON list stored at ``OUTPUT_JSON_PATH``.

    On every call the whole list is read, extended by one record, and written
    back. Problems are logged rather than raised, so a logging failure never
    breaks the request that produced the record.

    Args:
        image_name: Path of the stored image, or a placeholder such as ``"unknown"``.
        caption: Generated caption (or a message for failed attempts).
        caption_type: Caption style selected by the user.
        custom_prompt: User-supplied prompt, possibly empty.
        keywords: Keywords extracted from the caption.
        error: Error message for a failed attempt, ``None`` on success.
    """
    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "custom_prompt": custom_prompt,
        "keywords": keywords,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }

    data = []
    if os.path.exists(OUTPUT_JSON_PATH):
        try:
            with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise ValueError("expected a JSON list at the top level")
        except (OSError, ValueError) as e:
            # ValueError also covers JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error reading JSON file: {str(e)}")
            # Move the unreadable log aside instead of overwriting it, so the
            # existing history can still be recovered by hand.
            stamp = datetime.now().strftime("%Y%m%dT%H%M%S%f")
            backup_path = f"{OUTPUT_JSON_PATH}.corrupt-{stamp}"
            try:
                os.replace(OUTPUT_JSON_PATH, backup_path)
                logger.warning(f"Moved unreadable {OUTPUT_JSON_PATH} to {backup_path}")
            except OSError as move_err:
                logger.error(
                    f"Could not back up {OUTPUT_JSON_PATH} ({move_err}); result not saved"
                )
                return
            data = []

    data.append(result)

    tmp_path = f"{OUTPUT_JSON_PATH}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        # Write to a temp file and swap it in with os.replace (atomic on POSIX).
        # A crash mid-dump then leaves the previous log intact.
        os.replace(tmp_path, OUTPUT_JSON_PATH)
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
    except Exception as e:
        logger.error(f"Error writing to JSON file: {str(e)}")
        # Best-effort cleanup of the partial temp file; the error is already logged.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
|
# System message sent with every captioning request.
_CAPTION_SYSTEM_PROMPT = (
    "You are a helpful assistant that generates accurate and relevant image captions."
)
# File extensions picked up from batch archives (matched case-insensitively).
_BATCH_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _build_prompt(caption_type, custom_prompt):
    """Return the user prompt: the stripped custom prompt if non-blank, else a style default."""
    custom = (custom_prompt or "").strip()
    return custom if custom else f"Write a {caption_type} caption for this image."


def _generate_caption(image, caption_type, custom_prompt):
    """Caption one RGB PIL image and return ``(caption, keywords)``."""
    convo = [
        {"role": "system", "content": _CAPTION_SYSTEM_PROMPT},
        {"role": "user", "content": _build_prompt(caption_type, custom_prompt)},
    ]
    # Render the conversation with the model's chat template. Passing only the
    # bare user text would drop the system message and the template's
    # formatting, including where the image goes.
    prompt_text = processor.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    # The processor does its own resizing and normalisation, so the full-size
    # image is passed in rather than a pre-shrunk, aspect-distorted copy.
    inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
    # Match pixel tensors to the dtype of the loaded weights.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,  # without this, temperature/top_p are ignored
            temperature=0.7,
            top_p=0.9,
        )

    # generate() returns prompt + completion; decode only the new tokens so the
    # prompt text does not leak into the caption.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    caption = processor.decode(new_tokens, skip_special_tokens=True).strip()
    return caption, extract_keywords(caption)


def process_single_image(image, caption_type, custom_prompt):
    """Caption one uploaded PIL image, store a copy, and log the result.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was uploaded.
        caption_type: Caption style used when no custom prompt is given.
        custom_prompt: Optional user prompt; blank or ``None`` means "use the style".

    Returns:
        A human-readable result string, or an error message.
    """
    if image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, custom_prompt, [], error=error_msg)
        return error_msg

    image_name = os.path.join(UPLOAD_DIR, f"image_{uuid.uuid4().hex}.jpg")
    try:
        # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads), and
        # the model expects three channels, so normalise to RGB before saving.
        image = image.convert("RGB")
        image.save(image_name)
        caption, keywords = _generate_caption(image, caption_type, custom_prompt)
        save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None)
        return f"Caption: {caption}\nKeywords: {', '.join(keywords)}"
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, custom_prompt, [], error=error_msg)
        return error_msg


def _iter_image_files(root_dir):
    """Yield image file paths under ``root_dir`` in sorted order, skipping macOS metadata."""
    for root, dirs, files in os.walk(root_dir):
        # Prune macOS resource-fork folders and sort for a stable processing order.
        dirs[:] = sorted(d for d in dirs if d != "__MACOSX")
        for file in sorted(files):
            if file.startswith("._"):
                continue  # AppleDouble metadata, not a real image
            if file.lower().endswith(_BATCH_IMAGE_EXTENSIONS):
                yield os.path.join(root, file)


def _caption_batch_file(image_path, caption_type, custom_prompt):
    """Caption one extracted file, copy it to UPLOAD_DIR, log it, and return a result line."""
    # Keep the original extension so the stored copy's name matches its format.
    extension = os.path.splitext(image_path)[1].lower()
    image_name = os.path.join(UPLOAD_DIR, f"image_{uuid.uuid4().hex}{extension}")
    try:
        # The context manager releases the file handle before the temp dir is removed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        shutil.copy(image_path, image_name)
        caption, keywords = _generate_caption(image, caption_type, custom_prompt)
        save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None)
        return f"Image: {image_name}\nCaption: {caption}\nKeywords: {', '.join(keywords)}"
    except Exception as e:
        error_msg = f"Error processing {image_path}: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, custom_prompt, [], error=error_msg)
        return error_msg


def process_batch_images(zip_file, caption_type, custom_prompt):
    """Caption every .jpg/.jpeg/.png image inside an uploaded zip archive.

    Args:
        zip_file: The uploaded archive, as a file path string or a file-like
            object with a ``.name`` path (Gradio passes either, depending on
            version). ``None`` if nothing was uploaded.
        caption_type: Caption style used when no custom prompt is given.
        custom_prompt: Optional user prompt shared by all images.

    Returns:
        One result (or per-image error) paragraph per image, separated by blank
        lines, or a single message if the archive is missing, unreadable, or
        contains no images.
    """
    if zip_file is None:
        return "Please upload a zip file."

    zip_path = getattr(zip_file, "name", zip_file)
    results = []
    try:
        # A private directory per call: concurrent batches cannot collide, and it
        # is removed even when extraction or captioning fails part-way.
        with tempfile.TemporaryDirectory(prefix="temp_upload_") as temp_dir:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(temp_dir)
            for image_path in _iter_image_files(temp_dir):
                results.append(_caption_batch_file(image_path, caption_type, custom_prompt))
    except Exception as e:
        error_msg = f"Error processing batch: {str(e)}"
        logger.error(error_msg)
        return error_msg

    if not results:
        return "No images found in the zip file."
    return "\n\n".join(results)
|
|
|
|
|
def search_images(query):
    """Find logged, successfully captioned images whose caption or keywords contain ``query``.

    Matching is a case-insensitive substring test against the caption and
    each keyword. Failed attempts (records with an ``error``) and records
    whose image file no longer exists are skipped, because a gallery cannot
    display them.

    Args:
        query: Search text; surrounding whitespace is ignored.

    Returns:
        A list of ``(image_path, "Caption: ...\\nKeywords: ...")`` tuples, or a
        message string when the query is blank, no log exists, nothing
        matches, or reading the log fails.
    """
    needle = (query or "").strip().lower()
    if not needle:
        # An empty substring would match every record.
        return "Please enter a search query."

    try:
        if not os.path.exists(OUTPUT_JSON_PATH):
            return "No captions available."

        with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)

        results = []
        for entry in data:
            if entry.get("error"):
                continue
            image_name = entry.get("image_name")
            if not image_name or not os.path.exists(image_name):
                continue
            caption = entry.get("caption") or ""
            keywords = entry.get("keywords") or []
            if needle in caption.lower() or any(needle in kw.lower() for kw in keywords):
                results.append(
                    (image_name, f"Caption: {caption}\nKeywords: {', '.join(keywords)}")
                )

        return results if results else "No matches found."
    except Exception as e:
        logger.error(f"Error searching images: {str(e)}")
        return f"Error searching images: {str(e)}"
|
|
|
|
|
# Options shared by the single-image and batch caption-style dropdowns.
_CAPTION_STYLES = ["descriptive", "poetic", "humorous"]
_PROMPT_PLACEHOLDER = "e.g., 'Write a poetic caption'"


def _search_for_gallery(query):
    """Adapt ``search_images`` to the search tab's (gallery items, status text) outputs.

    ``search_images`` returns either a list of ``(image_path, caption)``
    tuples or a message string. A Gallery can only render the former, so
    messages go to a separate status textbox.
    """
    results = search_images(query)
    if isinstance(results, str):
        return [], results
    return results, f"{len(results)} match(es) found."


# gr.Interface wraps a single function, so each workflow gets its own
# Interface and the three are combined as tabs of one app.
_single_image_app = gr.Interface(
    fn=process_single_image,
    inputs=[
        gr.Image(label="Upload Single Image", type="pil"),
        gr.Dropdown(choices=_CAPTION_STYLES, label="Caption Style", value="descriptive"),
        gr.Textbox(label="Custom Prompt (optional)", placeholder=_PROMPT_PLACEHOLDER),
    ],
    outputs=gr.Textbox(label="Single Image Result"),
    description="Upload single or batch images, generate captions with custom styles, and search by captions or keywords. Results are saved to captions.json.",
)

_batch_app = gr.Interface(
    fn=process_batch_images,
    inputs=[
        gr.File(label="Upload Zip File for Batch Processing", file_types=[".zip"]),
        gr.Dropdown(choices=_CAPTION_STYLES, label="Caption Style", value="descriptive"),
        gr.Textbox(label="Custom Prompt (optional)", placeholder=_PROMPT_PLACEHOLDER),
    ],
    outputs=gr.Textbox(label="Batch Processing Results"),
)

_search_app = gr.Interface(
    fn=_search_for_gallery,
    inputs=gr.Textbox(label="Search Query", placeholder="e.g., 'beach'"),
    outputs=[gr.Gallery(label="Search Results"), gr.Textbox(label="Search Status")],
)

interface = gr.TabbedInterface(
    [_single_image_app, _batch_app, _search_app],
    tab_names=["Single Image", "Batch (Zip)", "Search"],
    title="Image Captioning with LLAVA",
)
|
|
|
def main():
    """Start the Gradio web server for the captioning app."""
    interface.launch()


# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|