Spaces:
Running
Running
Commit
·
f825473
1
Parent(s):
0545f86
Update: added handwritten extraction and product list storage
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
-
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
4 |
from PIL import Image
|
5 |
import io
|
6 |
import json
|
@@ -22,6 +22,9 @@ from torch.utils.tensorboard import SummaryWriter
|
|
22 |
import matplotlib.pyplot as plt
|
23 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
24 |
import matplotlib
|
|
|
|
|
|
|
25 |
matplotlib.use('Agg')
|
26 |
|
27 |
# Configure logging
|
@@ -94,32 +97,45 @@ def extract_fields(image_path):
|
|
94 |
return results
|
95 |
|
96 |
def extract_products(text):
|
97 |
-
#
|
98 |
-
|
|
|
99 |
matches = re.findall(product_pattern, text)
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
return products
|
102 |
|
103 |
def extract_with_perplexity_llm(ocr_text):
|
104 |
prompt = f"""
|
105 |
You are an expert at extracting structured data from receipts.
|
106 |
|
107 |
-
From the following OCR text, extract these fields and return them as a
|
108 |
- name (customer name)
|
109 |
- date (date of purchase)
|
110 |
-
- amount_paid (total amount paid
|
111 |
- receipt_no (receipt number)
|
112 |
-
-
|
113 |
-
|
114 |
-
**Note:** If the receipt has only one product, set 'product' to its name and 'amount_paid' to its price. If there is a 'price' and an 'amount paid', treat them as the same if they are equal.
|
115 |
|
116 |
Example output:
|
117 |
{{
|
118 |
"name": "Mrs. Genevieve Lopez",
|
119 |
"date": "12/13/2024",
|
120 |
-
"amount_paid":
|
121 |
"receipt_no": "042085",
|
122 |
-
"
|
|
|
|
|
|
|
123 |
}}
|
124 |
|
125 |
Text:
|
@@ -142,9 +158,32 @@ Text:
|
|
142 |
)
|
143 |
return response.choices[0].message.content
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def save_to_dynamodb(data, table_name="Receipts"):
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
table.put_item(Item=data)
|
149 |
|
150 |
def merge_extractions(regex_fields, llm_fields):
|
@@ -312,5 +351,19 @@ def main():
|
|
312 |
else:
|
313 |
st.info("Confusion matrix not found.")
|
314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
if __name__ == "__main__":
|
316 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification, TrOCRProcessor, VisionEncoderDecoderModel
|
4 |
from PIL import Image
|
5 |
import io
|
6 |
import json
|
|
|
22 |
import matplotlib.pyplot as plt
|
23 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
24 |
import matplotlib
|
25 |
+
import boto3
|
26 |
+
from decimal import Decimal
|
27 |
+
import uuid
|
28 |
matplotlib.use('Agg')
|
29 |
|
30 |
# Configure logging
|
|
|
97 |
return results
|
98 |
|
99 |
def extract_products(text, *, max_quantity=999):
    """Extract product line items from receipt OCR text.

    Each match consists of an optional leading number, an upper-case
    product name, and a price with exactly two decimals, e.g.
    ``"2 PISTACHIO 14.49"`` or ``"1076903 PISTACHIO 14.49"``.

    Args:
        text: OCR text to scan. ``None`` or an empty string yields ``[]``.
        max_quantity: Largest leading number still interpreted as a
            quantity. Anything outside ``1..max_quantity`` (for example a
            7-digit item code such as ``1076903``) is treated as an item
            code and the quantity falls back to 1, so it can no longer
            inflate ``total``.

    Returns:
        A list of dicts with the keys ``name`` (str), ``price`` (float),
        ``quantity`` (int) and ``total`` (float, rounded to 2 decimals).
    """
    # Pattern to match product lines with an optional leading number,
    # the product name, and the price.
    product_pattern = r"(?:(\d+)\s+)?([A-Z0-9 ]+)\s+(\d+\.\d{2})"

    products = []
    for leading_number, name, price_text in re.findall(product_pattern, text or ""):
        price = float(price_text)
        quantity = int(leading_number) if leading_number else 1
        # A leading number that is not a plausible quantity is an item
        # code, not a count; multiplying by it would corrupt the total.
        if not 1 <= quantity <= max_quantity:
            quantity = 1
        products.append(
            {
                "name": name.strip(),
                "price": price,
                "quantity": quantity,
                # Round to cents so float artifacts such as
                # 43.470000000000006 do not leak into stored totals.
                "total": round(price * quantity, 2),
            }
        )
    return products
|
117 |
|
118 |
def extract_with_perplexity_llm(ocr_text):
|
119 |
prompt = f"""
|
120 |
You are an expert at extracting structured data from receipts.
|
121 |
|
122 |
+
From the following OCR text, extract these fields and return them as a JSON object with exactly these keys:
|
123 |
- name (customer name)
|
124 |
- date (date of purchase)
|
125 |
+
- amount_paid (total amount paid)
|
126 |
- receipt_no (receipt number)
|
127 |
+
- products (a list of all products, each with name, price, and quantity if available)
|
|
|
|
|
128 |
|
129 |
Example output:
|
130 |
{{
|
131 |
"name": "Mrs. Genevieve Lopez",
|
132 |
"date": "12/13/2024",
|
133 |
+
"amount_paid": 29.69,
|
134 |
"receipt_no": "042085",
|
135 |
+
"products": [
|
136 |
+
{{"name": "Orange Juice", "price": 2.15, "quantity": 1}},
|
137 |
+
{{"name": "Apples", "price": 3.50, "quantity": 1}}
|
138 |
+
]
|
139 |
}}
|
140 |
|
141 |
Text:
|
|
|
158 |
)
|
159 |
return response.choices[0].message.content
|
160 |
|
161 |
+
def convert_floats_to_decimal(obj):
    """Recursively replace every ``float`` in ``obj`` with a ``Decimal``.

    Dicts, lists, tuples and sets are rebuilt as new containers of the
    same kind with their contents converted; any other value (including
    ``int``, ``bool``, ``str`` and ``None``) is returned unchanged.

    Floats are converted through ``str`` so that ``0.1`` becomes
    ``Decimal("0.1")`` rather than the long binary expansion produced by
    ``Decimal(0.1)``.

    Args:
        obj: Any value, typically a JSON-like structure.

    Returns:
        A structure equal in shape to ``obj`` with floats replaced.
    """
    if isinstance(obj, float):
        return Decimal(str(obj))
    if isinstance(obj, dict):
        return {key: convert_floats_to_decimal(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_floats_to_decimal(item) for item in obj]
    # Tuples and sets can also carry floats; convert those too so no
    # float survives anywhere in the structure.
    if isinstance(obj, tuple):
        return tuple(convert_floats_to_decimal(item) for item in obj)
    if isinstance(obj, set):
        return {convert_floats_to_decimal(item) for item in obj}
    return obj
|
170 |
+
|
171 |
def save_to_dynamodb(data, table_name="Receipts"):
    """Store an extracted receipt in a DynamoDB table.

    Before writing, the receipt is normalised:

    * If ``data`` has a ``products`` key but no truthy ``amount_paid``,
      the amount is derived from the products. A product's ``total`` is
      used when present; otherwise it is ``price * quantity`` (quantity
      defaults to 1). This value is also written back into the caller's
      ``data`` dict as a float.
    * All floats are converted to ``Decimal`` via
      ``convert_floats_to_decimal``.
    * A random UUID is used as ``receipt_no`` when none is present.

    Args:
        data: Receipt dict (e.g. name, date, amount_paid, receipt_no,
            products).
        table_name: Name of the DynamoDB table to write to.

    Raises:
        decimal.InvalidOperation: If a product's total, price or quantity
            cannot be parsed as a number.
    """
    dynamodb = boto3.resource('dynamodb')
    table = dynamodb.Table(table_name)

    # Calculate total amount if not provided
    if "products" in data and not data.get("amount_paid"):
        total = Decimal("0")
        # `or []` tolerates an explicit None in place of the product list.
        for product in data["products"] or []:
            if product.get("total") is not None:
                line_total = Decimal(str(product["total"]))
            else:
                # Products may carry only price/quantity and no "total"
                # key, so derive the line total instead of raising
                # KeyError.
                quantity = product.get("quantity") or 1
                price = product.get("price") or 0
                line_total = Decimal(str(price)) * Decimal(str(quantity))
            total += line_total
        # Summed in Decimal to avoid float drift; stored back as float so
        # the caller's dict keeps plain JSON-friendly types.
        data["amount_paid"] = float(total)

    # Convert all float values to Decimal for DynamoDB
    item = convert_floats_to_decimal(data)

    # Generate receipt number if not present
    if not item.get("receipt_no"):
        item["receipt_no"] = str(uuid.uuid4())

    table.put_item(Item=item)
|
188 |
|
189 |
def merge_extractions(regex_fields, llm_fields):
|
|
|
351 |
else:
|
352 |
st.info("Confusion matrix not found.")
|
353 |
|
354 |
+
# Load model and processor
|
355 |
+
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
|
356 |
+
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
|
357 |
+
|
358 |
+
# Load your image (crop to handwritten region if possible)
|
359 |
+
image = Image.open('handwritten_sample.jpg').convert("RGB")
|
360 |
+
|
361 |
+
# Preprocess and predict
|
362 |
+
pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
363 |
+
generated_ids = model.generate(pixel_values)
|
364 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
365 |
+
|
366 |
+
print("Handwritten text:", generated_text)
|
367 |
+
|
368 |
if __name__ == "__main__":
|
369 |
main()
|