import subprocess
import sys
import os
import re
import hashlib
import time
from typing import List, Dict, Any


def install_packages():
    """Install all required packages for the document intelligence system."""
    packages = [
        'gradio',
        'transformers',
        'torch',
        'torchvision',
        'Pillow',
        'pytesseract',
        'pdf2image',
        'opencv-python',
        'sentencepiece',
        'accelerate'
    ]

    print("Installing required packages...")
    for package in packages:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '-q'])

    print("Installing system dependencies...")
    subprocess.check_call(['apt-get', 'update', '-qq'])
    subprocess.check_call(['apt-get', 'install', '-y', '-qq', 'poppler-utils', 'tesseract-ocr'])
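

# Try the imports first and fall back to install_packages() on ImportError, so
# the script is self-bootstrapping on a fresh environment (e.g. Colab) but
# starts instantly when the dependencies are already present.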
try:
    import gradio as gr
    from transformers import (
        AutoProcessor, AutoModelForTokenClassification,
        AutoTokenizer, AutoModelForSeq2SeqLM,
        pipeline
    )
    import torch
    from PIL import Image
    import pytesseract
    from pdf2image import convert_from_path
    import cv2
    import numpy as np
except ImportError:
    print("Installing required packages...")
    install_packages()

    import gradio as gr
    from transformers import (
        AutoProcessor, AutoModelForTokenClassification,
        AutoTokenizer, AutoModelForSeq2SeqLM,
        pipeline
    )
    import torch
    from PIL import Image
    import pytesseract
    from pdf2image import convert_from_path
    import cv2
    import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading models...")
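
# LayoutLMv3 consumes words plus their bounding boxes; apply_ocr=False because
# OCR is done separately with pytesseract below. float16 on GPU roughly halves
# memory use at a negligible quality cost.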
print("Loading LayoutLMv3...") |
|
layoutlm_processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) |
|
layoutlm_model = AutoModelForTokenClassification.from_pretrained( |
|
"microsoft/layoutlmv3-base", |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
layoutlm_model.eval() |
|
|
|
|
|
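
# Flan-T5 is instruction-tuned, so a single model handles both the
# summarization and question-answering prompts used below.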
print("Loading T5 model for summarization and Q&A...") |
|
t5_model_name = "google/flan-t5-base" |
|
t5_tokenizer = AutoTokenizer.from_pretrained(t5_model_name) |
|
t5_model = AutoModelForSeq2SeqLM.from_pretrained( |
|
t5_model_name, |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
t5_model.eval() |
|
|
|
print("Models loaded successfully!") |
|
|
|


class DocumentProcessor:
    """
    Main document processing class that handles OCR, text extraction,
    summarization, and question answering for various document types.
    """

    def __init__(self):
        """Initialize the document processor with empty state."""
        self.extracted_text = ""
        self.document_metadata = {}
        self.page_contents = []
        self.processing_cache = {}

    def _get_file_hash(self, file_path: str) -> str:
        """Generate a hash of the file contents to use as a cache key."""
        with open(file_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()
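
    # pdf2image shells out to Poppler (installed above via poppler-utils).
    # 150 DPI is a reasonable trade-off between OCR accuracy and memory use.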
    def process_pdf(self, pdf_path: str, max_pages: int = 20) -> List[Image.Image]:
        """
        Convert PDF pages to images for OCR processing.

        Args:
            pdf_path: Path to the PDF file
            max_pages: Maximum number of pages to process (for memory management)

        Returns:
            List of PIL Images representing PDF pages
        """
        try:
            images = convert_from_path(
                pdf_path,
                dpi=150,
                first_page=1,
                last_page=min(max_pages, 100)
            )
            return images
        except Exception as e:
            print(f"Error processing PDF: {e}")
            return []

    def extract_text_from_image(self, image: Image.Image) -> Dict[str, Any]:
        """
        Extract text and layout information from an image using OCR.

        Args:
            image: PIL Image to process

        Returns:
            Dictionary containing extracted text and metadata
        """
        try:
            # Downscale very large images to keep OCR fast and memory-bounded.
            max_dimension = 2000
            if max(image.size) > max_dimension:
                ratio = max_dimension / max(image.size)
                new_size = tuple(int(dim * ratio) for dim in image.size)
                image = image.resize(new_size, Image.Resampling.LANCZOS)

            image_np = np.array(image)
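
            # --oem 3 uses Tesseract's default engine (LSTM where available);
            # --psm 6 assumes a single uniform block of text, which suits most
            # document pages.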
            ocr_config = '--oem 3 --psm 6'
            ocr_data = pytesseract.image_to_data(
                image_np,
                output_type=pytesseract.Output.DICT,
                config=ocr_config
            )
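
            # Keep only non-empty words with OCR confidence above 30;
            # low-confidence fragments are usually noise from rules, logos, or smudges.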
            words = []
            boxes = []
            confidences = []

            for i in range(len(ocr_data['text'])):
                if ocr_data['text'][i].strip() and ocr_data['conf'][i] > 30:
                    words.append(ocr_data['text'][i])
                    boxes.append([
                        ocr_data['left'][i],
                        ocr_data['top'][i],
                        ocr_data['left'][i] + ocr_data['width'][i],
                        ocr_data['top'][i] + ocr_data['height'][i]
                    ])
                    confidences.append(ocr_data['conf'][i])

            text = ' '.join(words)

            structured_text = text
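
            # LayoutLMv3 is capped at 512 tokens, so the layout-aware path is
            # only attempted for pages with fewer than 400 words; longer pages
            # fall back to the simple spatial grouping below.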
            if words and len(words) < 400:
                try:
                    # LayoutLMv3 expects box coordinates normalized to a
                    # 0-1000 range; pytesseract returns pixel coordinates.
                    width, height = image.size
                    norm_boxes = [
                        [
                            min(1000, max(0, int(1000 * box[0] / width))),
                            min(1000, max(0, int(1000 * box[1] / height))),
                            min(1000, max(0, int(1000 * box[2] / width))),
                            min(1000, max(0, int(1000 * box[3] / height)))
                        ]
                        for box in boxes[:400]
                    ]

                    encoding = layoutlm_processor(
                        image,
                        words[:400],
                        boxes=norm_boxes,
                        return_tensors="pt",
                        truncation=True,
                        padding="max_length",
                        max_length=512
                    )

                    encoding = {k: v.to(device) for k, v in encoding.items()}
                    # Match the image tensor to the model's dtype (float16 on GPU).
                    if 'pixel_values' in encoding:
                        encoding['pixel_values'] = encoding['pixel_values'].to(layoutlm_model.dtype)

                    with torch.no_grad():
                        outputs = layoutlm_model(**encoding)

                    # Note: the base checkpoint has no fine-tuned token-classification
                    # head, so the predictions are only used to align the word count.
                    predictions = outputs.logits.argmax(-1).squeeze().tolist()
                    if isinstance(predictions, int):
                        predictions = [predictions]

                    structured_text = self._structure_text(words[:len(predictions)], boxes[:len(predictions)])
                except Exception as e:
                    print(f"LayoutLM processing skipped: {e}")
                    structured_text = self._simple_structure_text(words, boxes)
            else:
                structured_text = self._simple_structure_text(words, boxes)

            return {
                'raw_text': text,
                'words': words,
                'boxes': boxes,
                'structured_text': structured_text,
                'num_words': len(words),
                'avg_confidence': sum(confidences) / len(confidences) if confidences else 0
            }

        except Exception as e:
            print(f"Error extracting text: {e}")
            return {
                'raw_text': "",
                'words': [],
                'boxes': [],
                'structured_text': "",
                'num_words': 0,
                'avg_confidence': 0
            }

    def _simple_structure_text(self, words: List[str], boxes: List[List[int]]) -> str:
        """
        Simple text structuring based on spatial layout.
        Groups words into lines based on vertical position.
        """
        if not words:
            return ""

        lines = []
        current_line = []
        last_y = None

        for word, box in zip(words, boxes):
            y_pos = box[1]

            # Words whose top edges are within 15px vertically belong to one line.
            if last_y is None or abs(y_pos - last_y) < 15:
                current_line.append(word)
            else:
                if current_line:
                    lines.append(' '.join(current_line))
                current_line = [word]

            last_y = y_pos

        if current_line:
            lines.append(' '.join(current_line))

        return '\n'.join(lines)

    def _structure_text(self, words: List[str], boxes: List[List[int]]) -> str:
        """Structure text using layout information; currently delegates to the simple line-grouping heuristic."""
        return self._simple_structure_text(words, boxes)

    def process_document(self, file_path: str) -> str:
        """
        Process any document type (PDF or image) and extract text.

        Args:
            file_path: Path to the document file

        Returns:
            Status message indicating success or failure
        """
        self.extracted_text = ""
        self.page_contents = []
        self.document_metadata = {
            'filename': os.path.basename(file_path),
            'pages': 0,
            'total_words': 0
        }
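
        # Serve repeat uploads from the in-memory cache, keyed by content hash.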
        file_hash = self._get_file_hash(file_path)
        if file_hash in self.processing_cache:
            cached_data = self.processing_cache[file_hash]
            self.extracted_text = cached_data['text']
            self.page_contents = cached_data['pages']
            self.document_metadata = cached_data['metadata']
            return f"✅ Loaded from cache: {self.document_metadata['filename']}\n" \
                   f"Pages: {self.document_metadata['pages']}\n" \
                   f"Words: {self.document_metadata['total_words']}"

        try:
            start_time = time.time()

            if file_path.lower().endswith('.pdf'):
                images = self.process_pdf(file_path)
                self.document_metadata['pages'] = len(images)

                for i, image in enumerate(images):
                    print(f"Processing page {i+1}/{len(images)}...")
                    page_data = self.extract_text_from_image(image)
                    self.page_contents.append(page_data)
                    self.extracted_text += f"\n\n--- Page {i+1} ---\n\n"
                    self.extracted_text += page_data['structured_text']
                    self.document_metadata['total_words'] += page_data['num_words']

            else:
                image = Image.open(file_path).convert('RGB')
                page_data = self.extract_text_from_image(image)
                self.page_contents.append(page_data)
                self.extracted_text = page_data['structured_text']
                self.document_metadata['pages'] = 1
                self.document_metadata['total_words'] = page_data['num_words']

            self.processing_cache[file_hash] = {
                'text': self.extracted_text,
                'pages': self.page_contents,
                'metadata': self.document_metadata
            }

            processing_time = time.time() - start_time

            if self.document_metadata['total_words'] == 0:
                return f"⚠️ No text found in {self.document_metadata['filename']}. Please ensure the document contains readable text."

            return f"✅ Successfully processed {self.document_metadata['filename']}\n" \
                   f"Pages: {self.document_metadata['pages']}\n" \
                   f"Words extracted: {self.document_metadata['total_words']}\n" \
                   f"⏱️ Processing time: {processing_time:.1f}s"

        except Exception as e:
            return f"❌ Error processing document: {str(e)}"

    def summarize_document(self) -> str:
        """
        Generate a concise summary of the document using the T5 model.

        Returns:
            Document summary or error message
        """
        if not self.extracted_text:
            return "No document has been processed yet. Please upload and process a document first."

        try:
            start_time = time.time()
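
            # flan-t5-base has a limited input window, so only the first ~2048
            # characters are summarized; the tokenizer truncates further below.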
            text_to_summarize = self.extracted_text[:2048]

            prompt = f"Summarize the following document:\n\n{text_to_summarize}"

            inputs = t5_tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            ).to(device)

            with torch.no_grad():
                summary_ids = t5_model.generate(
                    inputs.input_ids,
                    max_length=150,
                    min_length=30,
                    num_beams=4,
                    length_penalty=2.0,
                    early_stopping=True
                )

            summary = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

            generation_time = time.time() - start_time

            return f"{summary}\n\n⏱️ Generated in {generation_time:.1f}s"

        except Exception as e:
            return f"Error generating summary: {str(e)}"

    def answer_question(self, question: str) -> str:
        """
        Answer questions about the document using the T5 model.

        Args:
            question: User's question about the document

        Returns:
            Answer to the question
        """
        if not self.extracted_text:
            return "Please upload and process a document first."

        if not question.strip():
            return "Please enter a question."

        try:
            start_time = time.time()
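
            # Use the start of the document as context; a slightly shorter slice
            # than for summarization leaves room for the question in the prompt.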
            context = self.extracted_text[:1536]

            prompt = f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"

            inputs = t5_tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            ).to(device)

            with torch.no_grad():
                answer_ids = t5_model.generate(
                    inputs.input_ids,
                    max_length=100,
                    min_length=5,
                    num_beams=3,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9
                )

            answer = t5_tokenizer.decode(answer_ids[0], skip_special_tokens=True)

            generation_time = time.time() - start_time

            return f"{answer}\n\n⏱️ Generated in {generation_time:.1f}s"

        except Exception as e:
            return f"Error answering question: {str(e)}"

    def extract_key_information(self) -> Dict[str, List[str]]:
        """
        Extract key entities from the document using regex patterns.

        Returns:
            Dictionary of extracted entities organized by type
        """
        if not self.extracted_text:
            return {"message": ["No document has been processed yet."]}

        try:
            entities = {
                'dates': [],
                'emails': [],
                'phone_numbers': [],
                'monetary_amounts': [],
                'percentages': [],
                'urls': []
            }
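
            # Dates: numeric (12/31/2024, 2024-12-31) and written
            # (December 31, 2024 / 31 December 2024) forms.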
            date_patterns = [
                r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',
                r'\b\d{4}[/-]\d{1,2}[/-]\d{1,2}\b',
                r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December|Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2},?\s+\d{4}\b',
                r'\b\d{1,2}\s+(?:January|February|March|April|May|June|July|August|September|October|November|December|Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{4}\b'
            ]

            for pattern in date_patterns:
                matches = re.findall(pattern, self.extracted_text, re.IGNORECASE)
                entities['dates'].extend(matches)

            email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
            entities['emails'] = re.findall(email_pattern, self.extracted_text)

            phone_patterns = [
                r'\b\+?1?\s*\(?([0-9]{3})\)?[-.\s]?([0-9]{3})[-.\s]?([0-9]{4})\b',
                r'\b\d{3}[-.\s]\d{3}[-.\s]\d{4}\b'
            ]

            for pattern in phone_patterns:
                matches = re.findall(pattern, self.extracted_text)
                # Patterns with capture groups return tuples of groups; join
                # them back into a single number string.
                if matches and isinstance(matches[0], tuple):
                    entities['phone_numbers'].extend(['-'.join(match) for match in matches])
                else:
                    entities['phone_numbers'].extend(matches)

            money_patterns = [
                r'\$\s*[\d,]+\.?\d*',
                r'USD\s*[\d,]+\.?\d*',
                r'\b\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:dollars?|USD)\b'
            ]

            for pattern in money_patterns:
                matches = re.findall(pattern, self.extracted_text, re.IGNORECASE)
                entities['monetary_amounts'].extend(matches)

            percent_pattern = r'\b\d+\.?\d*\s*%'
            entities['percentages'] = re.findall(percent_pattern, self.extracted_text)

            url_pattern = r'https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&/=]*)'
            entities['urls'] = re.findall(url_pattern, self.extracted_text)

            # Deduplicate while preserving order, and cap each list at 10 items.
            for key in entities:
                unique_items = list(dict.fromkeys(entities[key]))
                entities[key] = unique_items[:10]

            # Drop empty categories.
            entities = {k: v for k, v in entities.items() if v}

            if not entities:
                entities = {"info": ["No specific entities found. The document may need better quality or contain different types of information."]}

            return entities

        except Exception as e:
            return {"error": [f"Error extracting information: {str(e)}"]}
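

# A single module-level processor instance backs all of the Gradio callbacks,
# so its cache and extracted text persist across interactions.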
processor = DocumentProcessor()


def process_document_handler(file):
    """Handle document upload and processing."""
    if file is None:
        return "Please upload a document.", "", {}

    status = processor.process_document(file)

    # Preview only the first 1000 characters of the extracted text.
    text_preview = (
        processor.extracted_text[:1000] + "..."
        if len(processor.extracted_text) > 1000
        else processor.extracted_text
    )

    key_info = processor.extract_key_information()

    return status, text_preview, key_info


def summarize_handler():
    """Handle document summarization request."""
    return processor.summarize_document()


def qa_handler(question):
    """Handle question answering request."""
    if not question:
        return "Please enter a question."
    return processor.answer_question(question)


def create_interface():
    """Create the Gradio interface for the document intelligence system."""
    with gr.Blocks(title="Multi-Modal Document Intelligence System", theme=gr.themes.Soft()) as interface:

        gr.Markdown("""
        # 🧠 Multi-Modal Document Intelligence System

        **Upload any document (PDF or image) and unlock its insights with AI!**

        This advanced system combines:
        - **LayoutLMv3** for understanding document structure and layout
        - 🤖 **Flan-T5** for intelligent summarization and question answering
        - **OCR technology** for accurate text extraction from any document

        ### ✨ Features
        - Upload PDFs or images (JPG, PNG, etc.)
        - Automatic text extraction with layout understanding
        - Intelligent document summarization
        - Natural language Q&A about your documents
        - Key information extraction (dates, emails, amounts, etc.)
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document",
                    file_types=[".pdf", ".png", ".jpg", ".jpeg", ".bmp", ".tiff"],
                    type="filepath"
                )

                process_btn = gr.Button("Process Document", variant="primary", size="lg")

                status_output = gr.Textbox(
                    label="Processing Status",
                    lines=4,
                    interactive=False
                )

                gr.Markdown("### Key Information Extracted")
                key_info_output = gr.JSON(label="Extracted Entities", elem_id="key_info")

            with gr.Column(scale=2):
                text_preview = gr.Textbox(
                    label="Document Text Preview",
                    lines=10,
                    max_lines=15,
                    interactive=False
                )

                with gr.Tab("Summary"):
                    summary_btn = gr.Button("Generate Summary", variant="secondary")
                    summary_output = gr.Textbox(
                        label="Document Summary",
                        lines=8,
                        interactive=False
                    )

                with gr.Tab("Q&A"):
                    question_input = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="e.g., What are the main points? What dates are mentioned? What is the total amount?",
                        lines=2
                    )
                    qa_btn = gr.Button("Get Answer", variant="secondary")
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=6,
                        interactive=False
                    )

                gr.Markdown("### Example Questions")
                gr.Examples(
                    examples=[
                        "What is the main topic of this document?",
                        "What dates are mentioned?",
                        "What is the total amount due?",
                        "Who are the key people mentioned?",
                        "What are the main findings?",
                        "Summarize the key points."
                    ],
                    inputs=question_input
                )

        gr.Markdown("""
        ---
        ### 🎯 How to Use
        1. **Upload** a PDF or image document
        2. **Process** the document to extract text
        3. **Review** the extracted text and key information
        4. **Generate** a summary or ask questions

        ### 💡 Tips for Best Results
        - Use clear, high-quality documents
        - For images, ensure good lighting and contrast
        - The system works with multiple languages
        - Processing time depends on document size and complexity

        ---
        👨‍💻 **Created by Spencer Purdy** | Computer Science @ Auburn University
        [GitHub](https://github.com/spencercpurdy) | [LinkedIn](https://linkedin.com/in/spencerpurdy) | [Hugging Face](https://huggingface.co/spencercpurdy)
        """)
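
        # Event wiring: clicking the buttons (and pressing Enter in the
        # question box) routes through the module-level handlers defined above.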
        process_btn.click(
            fn=process_document_handler,
            inputs=file_input,
            outputs=[status_output, text_preview, key_info_output]
        )

        summary_btn.click(
            fn=summarize_handler,
            inputs=[],
            outputs=summary_output
        )

        qa_btn.click(
            fn=qa_handler,
            inputs=question_input,
            outputs=answer_output
        )

        question_input.submit(
            fn=qa_handler,
            inputs=question_input,
            outputs=answer_output
        )

    return interface


if __name__ == "__main__":
    print("Starting Multi-Modal Document Intelligence System...")

    interface = create_interface()

    # share=True exposes a public Gradio link, which is useful when running
    # inside a hosted notebook such as Colab.
    interface.launch(
        debug=True,
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )