File size: 8,753 Bytes
7377f5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import json
import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from datetime import datetime

import gradio as gr
import spacy
import torch
from PIL import Image
from spacy.cli import download
from transformers import LlavaForConditionalGeneration, AutoProcessor

# Logging: INFO-level root configuration plus a module-scoped logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Storage locations: the JSON metadata store and the folder for saved uploads
OUTPUT_JSON_PATH = "captions.json"
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)


# SpaCy pipeline used for keyword extraction
def _load_spacy_pipeline(model_name):
    """Load a SpaCy pipeline, downloading it once if loading raises OSError."""
    try:
        return spacy.load(model_name)
    except OSError:
        logger.info("Downloading en_core_web_sm model...")
        download(model_name)
        return spacy.load(model_name)


try:
    nlp = _load_spacy_pipeline("en_core_web_sm")
except Exception as exc:
    # Any failure (including a failed download or retry) is logged, then
    # re-raised so the module does not continue without a pipeline.
    logger.error(f"Error loading SpaCy model: {exc}")
    raise

# LLAVA checkpoint: processor and model are loaded once at import time
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
try:
    # Keyword arguments forwarded to from_pretrained for the model weights.
    _model_load_options = {
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
        "device_map": "auto",
    }
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, **_model_load_options)
    model.eval()
    logger.info("Model and processor loaded successfully.")
except Exception as exc:
    # Log the failure and re-raise; nothing below can work without the model.
    logger.error(f"Error loading model: {exc}")
    raise

# Function to extract keywords
def extract_keywords(text, max_keywords=5):
    """Return up to ``max_keywords`` distinct keywords found in ``text``.

    A keyword is the lower-cased text of a token that the module-level SpaCy
    pipeline ``nlp`` tags as ``NOUN`` or ``ADJ`` and that is not a stop word.
    Keywords are returned in order of first appearance, so the same text
    always yields the same list (de-duplicating through ``set`` made both the
    selection and the order arbitrary).

    Args:
        text: The text to analyse.
        max_keywords: Maximum number of keywords to return (default 5).

    Returns:
        A list of keyword strings; an empty list if extraction fails.
    """
    try:
        doc = nlp(text)
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_keywords = dict.fromkeys(
            token.text.lower()
            for token in doc
            if token.pos_ in ("NOUN", "ADJ") and not token.is_stop
        )
        return list(unique_keywords)[:max_keywords]
    except Exception as e:
        # Best effort: keyword extraction must never break captioning.
        logger.error(f"Error extracting keywords: {str(e)}")
        return []

# Function to save metadata to JSON
def save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None):
    """Append one captioning record to the JSON store at ``OUTPUT_JSON_PATH``.

    The store is a JSON list of records. The whole list is read, the new
    record appended, and the list written back. All failures are logged and
    swallowed: persisting metadata is best effort and must not break the
    request that produced the caption.

    Args:
        image_name: Path or name of the processed image.
        caption: Generated caption (empty string on failure).
        caption_type: Requested caption style.
        custom_prompt: User-supplied prompt, if any.
        keywords: List of keywords extracted from the caption.
        error: Error message for failed runs, or ``None`` on success.
    """
    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "custom_prompt": custom_prompt,
        "keywords": keywords,
        "timestamp": datetime.now().isoformat(),
        "error": error
    }

    data = []
    if os.path.exists(OUTPUT_JSON_PATH):
        try:
            with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except Exception as e:
            # Unreadable or corrupt store: start a fresh list (as before).
            logger.error(f"Error reading JSON file: {str(e)}")
        else:
            if isinstance(loaded, list):
                data = loaded
            else:
                # Valid JSON that is not a list would make append() raise
                # outside any handler; treat it like an unreadable store.
                logger.error(
                    f"Error reading JSON file: expected a list, got {type(loaded).__name__}"
                )

    data.append(result)
    try:
        # Serialise before opening the file: open(..., "w") truncates, so a
        # serialisation error raised mid-dump would leave a corrupt store.
        payload = json.dumps(data, indent=4)
        with open(OUTPUT_JSON_PATH, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
    except Exception as e:
        logger.error(f"Error writing to JSON file: {str(e)}")

# System message sent with every captioning request
_SYSTEM_PROMPT = "You are a helpful assistant that generates accurate and relevant image captions."

# File extensions accepted from uploaded zip archives
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _build_prompt(caption_type, custom_prompt):
    """Return the custom prompt if it is non-blank, else the default prompt."""
    custom = (custom_prompt or "").strip()
    return custom if custom else f"Write a {caption_type} caption for this image."


def _generate_caption(image, caption_type, custom_prompt):
    """Run the model on one PIL image and return ``(caption, keywords)``.

    Shared by the single-image and batch entry points so both use the same
    prompt, generation settings and decoding.

    Raises:
        Exception: Anything raised by image conversion, the processor or the
            model is propagated to the caller.
    """
    image = image.convert("RGB").resize((256, 256))
    convo = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": _build_prompt(caption_type, custom_prompt)},
    ]
    # The chat template formats the system and user turns for the processor
    # (per the model card it also handles the image placeholder); passing the
    # raw user text skipped the template and dropped the system message.
    convo_string = processor.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    # Follow the model's own device and dtype: with device_map="auto" the
    # weights are not necessarily on "cuda", and float32 pixel values do not
    # match a float16 model.
    inputs = processor(
        images=[image], text=[convo_string], return_tensors="pt"
    ).to(model.device, model.dtype)

    with torch.no_grad():
        # temperature/top_p only take effect when sampling is enabled.
        output = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # generate() returns prompt + continuation; keep only the new tokens so
    # the prompt text is not echoed into the caption.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    caption = processor.decode(new_tokens, skip_special_tokens=True).strip()
    return caption, extract_keywords(caption)


# Function to process single image
def process_single_image(image, caption_type, custom_prompt):
    """Caption one uploaded image and record the result.

    Args:
        image: PIL image from the UI, or ``None`` if nothing was uploaded.
        caption_type: Caption style used to build the default prompt.
        custom_prompt: Optional prompt that overrides the default when
            non-blank.

    Returns:
        ``"Caption: ...\\nKeywords: ..."`` on success, otherwise an error
        message. Every outcome is also appended to the JSON store.
    """
    if image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, custom_prompt, [], error=error_msg)
        return error_msg

    image_name = os.path.join(UPLOAD_DIR, f"image_{uuid.uuid4().hex}.jpg")

    try:
        # JPEG cannot store alpha or palette modes, so convert before saving.
        # Saving sits inside the try so a disk or format error is reported to
        # the user instead of escaping as an unhandled exception.
        image = image.convert("RGB")
        image.save(image_name)
        caption, keywords = _generate_caption(image, caption_type, custom_prompt)
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, custom_prompt, [], error=error_msg)
        return error_msg

    save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None)
    return f"Caption: {caption}\nKeywords: {', '.join(keywords)}"


# Function to process batch images
def process_batch_images(zip_file, caption_type, custom_prompt):
    """Caption every supported image inside an uploaded zip archive.

    Args:
        zip_file: The uploaded archive: an object with a ``.name`` path
            attribute or a plain path string. ``None`` if nothing was
            uploaded.
        caption_type: Caption style used to build the default prompt.
        custom_prompt: Optional prompt that overrides the default when
            non-blank.

    Returns:
        Per-image results (or per-image error messages) joined by blank
        lines, or a single message when the upload is missing, contains no
        supported images, or cannot be processed at all.
    """
    if zip_file is None:
        return "Please upload a zip file."

    # Accept both a file-like wrapper exposing .name and a bare path string.
    zip_path = getattr(zip_file, "name", zip_file)
    results = []

    try:
        # A unique directory per call keeps concurrent batches from mixing
        # files, and the context manager removes it even when processing fails.
        with tempfile.TemporaryDirectory(prefix="temp_upload_") as temp_dir:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(temp_dir)

            for root, _, files in os.walk(temp_dir):
                for file in files:
                    extension = os.path.splitext(file)[1].lower()
                    if extension not in _IMAGE_EXTENSIONS:
                        continue

                    image_path = os.path.join(root, file)
                    # Path inside the archive, used in user-facing messages.
                    member = os.path.relpath(image_path, temp_dir)
                    # Keep the real extension: the file is copied unchanged,
                    # so naming a PNG "*.jpg" would mislabel it.
                    image_name = os.path.join(
                        UPLOAD_DIR, f"image_{uuid.uuid4().hex}{extension}"
                    )

                    try:
                        shutil.copy(image_path, image_name)
                        # Close the file handle as soon as inference is done.
                        with Image.open(image_path) as opened:
                            caption, keywords = _generate_caption(
                                opened, caption_type, custom_prompt
                            )
                    except Exception as e:
                        # One bad image must not abort the rest of the batch.
                        error_msg = f"Error processing {member}: {str(e)}"
                        logger.error(error_msg)
                        save_to_json(image_name, "", caption_type, custom_prompt, [], error=error_msg)
                        results.append(error_msg)
                        continue

                    save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None)
                    results.append(
                        f"Image: {image_name}\nCaption: {caption}\nKeywords: {', '.join(keywords)}"
                    )

        if not results:
            return "No supported images (.jpg, .jpeg, .png) found in the zip file."
        return "\n\n".join(results)
    except Exception as e:
        error_msg = f"Error processing batch: {str(e)}"
        logger.error(error_msg)
        return error_msg

# Function to search images
def search_images(query):
    """Find stored captions whose caption or keywords contain ``query``.

    Matching is a case-insensitive substring test. An empty (or ``None``)
    query matches every successfully captioned record.

    Args:
        query: Search text.

    Returns:
        A list of ``(image_name, description)`` tuples for the matches, or a
        message string when the store is missing, nothing matches, or an
        error occurs.
    """
    try:
        if not os.path.exists(OUTPUT_JSON_PATH):
            return "No captions available."

        with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Lower-case the query once instead of per entry and per keyword.
        needle = (query or "").lower()

        results = []
        for entry in data:
            # Skip malformed records rather than failing the whole search.
            if not isinstance(entry, dict):
                continue
            # Failed runs carry no real caption (and may carry the
            # placeholder name "unknown"), so they are not search results.
            if entry.get("error"):
                continue

            caption = entry.get("caption") or ""
            keywords = entry.get("keywords") or []
            matches_caption = needle in caption.lower()
            matches_keyword = any(needle in str(kw).lower() for kw in keywords)
            if matches_caption or matches_keyword:
                description = (
                    f"Caption: {caption}\nKeywords: {', '.join(str(kw) for kw in keywords)}"
                )
                results.append((entry.get("image_name", "unknown"), description))

        return results if results else "No matches found."
    except Exception as e:
        logger.error(f"Error searching images: {str(e)}")
        return f"Error searching images: {str(e)}"

# Gradio interface
# gr.Interface wraps exactly one callable with a flat list of inputs, so the
# three tools (single image, batch, search) are laid out as tabs inside one
# gr.Blocks app instead of being passed as lists to a single Interface.
_CAPTION_STYLES = ["descriptive", "poetic", "humorous"]


def _search_for_gallery(query):
    """Adapt ``search_images`` output to ``(gallery_items, status_text)``.

    ``search_images`` returns either a list of ``(image, caption)`` tuples or
    a plain message string. A Gallery can only display the former, so
    messages are routed to a separate status textbox.
    """
    found = search_images(query)
    if isinstance(found, str):
        return [], found
    return found, f"{len(found)} match(es) found."


with gr.Blocks(title="Image Captioning with LLAVA") as interface:
    gr.Markdown("# Image Captioning with LLAVA")
    gr.Markdown(
        "Upload single or batch images, generate captions with custom styles, and search by captions or keywords. Results are saved to captions.json."
    )

    with gr.Tab("Single Image"):
        single_image = gr.Image(label="Upload Single Image", type="pil")
        single_style = gr.Dropdown(choices=_CAPTION_STYLES, label="Caption Style", value="descriptive")
        single_prompt = gr.Textbox(label="Custom Prompt (optional)", placeholder="e.g., 'Write a poetic caption'")
        single_button = gr.Button("Generate Caption")
        single_result = gr.Textbox(label="Single Image Result")
        single_button.click(
            fn=process_single_image,
            inputs=[single_image, single_style, single_prompt],
            outputs=single_result,
        )

    with gr.Tab("Batch Processing"):
        batch_file = gr.File(label="Upload Zip File for Batch Processing", file_types=[".zip"])
        batch_style = gr.Dropdown(choices=_CAPTION_STYLES, label="Caption Style", value="descriptive")
        batch_prompt = gr.Textbox(label="Custom Prompt (optional)", placeholder="e.g., 'Write a poetic caption'")
        batch_button = gr.Button("Process Batch")
        batch_result = gr.Textbox(label="Batch Processing Results")
        batch_button.click(
            fn=process_batch_images,
            inputs=[batch_file, batch_style, batch_prompt],
            outputs=batch_result,
        )

    with gr.Tab("Search"):
        search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'beach'")
        search_button = gr.Button("Search")
        search_gallery = gr.Gallery(label="Search Results")
        search_status = gr.Textbox(label="Search Status")
        search_button.click(
            fn=_search_for_gallery,
            inputs=search_query,
            outputs=[search_gallery, search_status],
        )

if __name__ == "__main__":
    interface.launch()