""" | |
TextLens - AI-Powered OCR Application | |
Main entry point for the application. | |
""" | |
import gradio as gr
import torch
import time
import logging
from threading import Thread
from PIL import Image
from transformers import (
    AutoProcessor,
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configurations
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
ROLMOCR_MODEL_ID = "reducto/RolmOCR"
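# QV_MODEL_ID is the fast, general-purpose OCR model; ROLMOCR_MODEL_ID is the
# layout-aware option for complex documents and tables (selectable in the UI).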


def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """Returns an HTML snippet for a thin animated progress bar with a label."""
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
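
# This snippet is yielded as the first chunk in extract_text_from_image() so the
# output area shows activity before the model produces any tokens.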

# Load models at startup
logger.info("Loading OCR models...")
logger.info("This may take a few minutes on first run...")

try:
    # Load Qwen2VL OCR model (primary fast model)
    logger.info(f"Loading Qwen2VL OCR model: {QV_MODEL_ID}")
    qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        QV_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    ).to("cuda" if torch.cuda.is_available() else "cpu").eval()
    logger.info("Qwen2VL OCR model loaded successfully!")

    # Load RolmOCR model (specialized document model)
    logger.info(f"Loading RolmOCR model: {ROLMOCR_MODEL_ID}")
    rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
    rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        ROLMOCR_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
    ).to("cuda" if torch.cuda.is_available() else "cpu").eval()
    logger.info("RolmOCR model loaded successfully!")

    MODELS_LOADED = True
    logger.info("All models loaded and ready!")
except Exception as e:
    logger.error(f"Failed to load models: {str(e)}")
    MODELS_LOADED = False
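
# MODELS_LOADED gates every OCR request below and is surfaced to the user via
# get_model_status(); a load failure degrades the app to an informative error
# state rather than crashing at startup.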


def extract_text_from_image(image, text_query, use_rolmocr=False):
    """Extract text from an image using the selected OCR model, streaming the response."""
    if not MODELS_LOADED:
        yield "Error: OCR models failed to load. Please check your setup and try again."
        return

    if image is None:
        yield "No image provided. Please upload an image to extract text."
        return

    try:
        # Ensure image is a PIL image in RGB format
        if not isinstance(image, Image.Image):
            yield "Invalid image format. Please upload a valid image file."
            return
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Prepare text query
        if not text_query.strip():
            text_query = "Extract all text from this image"

        # Select model and processor
        if use_rolmocr:
            processor = rolmocr_processor
            model = rolmocr_model
            model_name = "RolmOCR"
            logger.info("Using RolmOCR for specialized document processing")
        else:
            processor = qwen_processor
            model = qwen_model
            model_name = "Qwen2VL OCR"
            logger.info("Using Qwen2VL OCR for fast text extraction")
# Build messages for the model | |
messages = [ | |
{ | |
"role": "user", | |
"content": [ | |
{"type": "text", "text": text_query}, | |
{"type": "image", "image": image} | |
] | |
} | |
] | |
# Apply chat template and prepare inputs | |
prompt_full = processor.apply_chat_template( | |
messages, | |
tokenize=False, | |
add_generation_prompt=True | |
) | |
inputs = processor( | |
text=[prompt_full], | |
images=[image], | |
return_tensors="pt", | |
padding=True, | |
).to("cuda" if torch.cuda.is_available() else "cpu") | |
# Set up streaming | |
streamer = TextIteratorStreamer( | |
processor, | |
skip_prompt=True, | |
skip_special_tokens=True | |
) | |
generation_kwargs = dict( | |
inputs, | |
streamer=streamer, | |
max_new_tokens=1024, | |
do_sample=False, | |
temperature=0.1 | |
) | |

        # Start generation in a separate thread so tokens can be streamed as they arrive
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield progress bar first
        yield progress_bar_html(f"Processing with {model_name}")

        # Stream the response
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            # Clean up any special tokens that might leak through
            clean_buffer = buffer.replace("<|im_end|>", "").replace("<|endoftext|>", "").strip()
            if clean_buffer:
                time.sleep(0.01)  # Small delay for smooth streaming
                yield clean_buffer

        # Ensure the generation thread completes
        thread.join()

        # Final clean response
        final_response = buffer.replace("<|im_end|>", "").replace("<|endoftext|>", "").strip()
        if not final_response:
            yield "No text was detected in the image. Please try a clearer image or a different model."
        else:
            logger.info(f"Successfully extracted text: {len(final_response)} characters")
            yield final_response

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        logger.error(f"OCR processing failed: {str(e)}")
        yield error_msg
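
# Usage sketch (not executed by the app): the function above is a generator, so it
# can also be consumed outside Gradio. "sample.png" is a hypothetical local file
# used purely for illustration; the first yield is the progress-bar HTML, and each
# later yield carries the cumulative cleaned text so far.
#
#   from PIL import Image
#   img = Image.open("sample.png")
#   result = ""
#   for chunk in extract_text_from_image(img, "Extract all text from this image"):
#       result = chunk
#   print(result)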


def get_model_status():
    """Get current model status information."""
    if MODELS_LOADED:
        device = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
        return f"""
**Model Status: Ready**

- **Primary Model:** Qwen2VL-OCR-2B (Fast general OCR)
- **Secondary Model:** RolmOCR (Specialized documents)
- **Device:** {device}
- **Memory:** Optimized for streaming inference

Both models loaded and ready for OCR processing!
"""
    else:
        return """
**Model Status: Failed to Load**

Please check your internet connection and GPU setup.
Models need to be downloaded on first run.
"""


# Create Gradio interface
def create_interface():
    """Create the streamlined OCR interface."""
    with gr.Blocks(
        title="TextLens - Fast AI OCR",
        theme=gr.themes.Soft(),
        css="""
        .container { max-width: 1200px; margin: auto; }
        .header { text-align: center; padding: 20px; }
        .model-status { background: #f0f0f0; padding: 15px; border-radius: 8px; margin: 10px 0; }
        """
    ) as interface:
        # Header
        gr.HTML("""
        <div class="header">
            <h1>TextLens - AI-Powered OCR</h1>
            <p style="font-size: 16px; color: #666;">
                Fast and accurate text extraction using modern AI models
            </p>
        </div>
        """)

        # Model status
        with gr.Row():
            with gr.Column():
                status_display = gr.Markdown(
                    value=get_model_status(),
                    elem_classes=["model-status"]
                )
                refresh_btn = gr.Button("Refresh Status", size="sm")

        # Main interface
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Image")
                image_input = gr.Image(
                    label="Upload image for OCR",
                    type="pil",
                    sources=["upload", "clipboard"]
                )
                text_query = gr.Textbox(
                    label="OCR Instructions (optional)",
                    placeholder="Extract all text from this image",
                    value="Extract all text from this image",
                    lines=2
                )
                use_rolmocr = gr.Checkbox(
                    label="Use RolmOCR (specialized for documents)",
                    value=False,
                    info="Check for complex documents/tables, uncheck for general text"
                )
                extract_btn = gr.Button(
                    "Extract Text",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("### Extracted Text")
                text_output = gr.Textbox(
                    label="OCR Results",
                    lines=15,
                    max_lines=25,
                    placeholder=(
                        "Extracted text will appear here...\n\n"
                        "• Upload an image to get started\n"
                        "• Choose between fast OCR or specialized document processing\n"
                        "• Results will stream in real-time"
                    ),
                    show_copy_button=True
                )

        # Event handlers
        extract_btn.click(
            fn=extract_text_from_image,
            inputs=[image_input, text_query, use_rolmocr],
            outputs=text_output,
            show_progress="hidden"  # We handle progress with custom HTML
        )

        # Auto-extract on image upload
        image_input.upload(
            fn=extract_text_from_image,
            inputs=[image_input, text_query, use_rolmocr],
            outputs=text_output,
            show_progress="hidden"
        )

        refresh_btn.click(
            fn=get_model_status,
            outputs=status_display
        )

    return interface
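
# Note: the streaming outputs rely on Gradio's event queue. Recent Gradio releases
# enable the queue by default; on older versions an explicit interface.queue()
# call before launch() may be needed.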


if __name__ == "__main__":
    logger.info("Starting TextLens OCR application...")
    try:
        interface = create_interface()

        # Launch configuration
        interface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            debug=False
        )
    except Exception as e:
        logger.error(f"Failed to start application: {str(e)}")
        raise